#!/usr/bin/env python

# This unit test doesn't use any of the debugger code. It is meant solely
# to test the connection classes.

import os
import sys
import socket
import time
import thread
import unittest

from test import test_support
from socket import gaierror

# Global vars
# "host:port" address shared by all TCP tests. The name ends in "__", so it
# is NOT name-mangled when referenced from inside the test classes.
__addr__ = 'localhost:8002'
# Maximum number of connection attempts made by repeatedConnect().
MAXTRIES = 100
# File that TestSerialConnections creates in setUp, uses as a stand-in serial
# device, and removes again in tearDown.
TESTFN = 'device'

sys.path.append("..")
from mconnection import (MConnectionServerTCP, MConnectionClientTCP,
                         MConnectionSerial, MConnectionServerFIFO,
                         MConnectionClientFIFO, ConnectionFailed,
                         ReadError, WriteError)

# Try to connect the client to addr either until we've tried MAXTRIES
# times or until it succeeds.
def repeatedConnect(client, addr, tries=None, delay=0.1):
    """Keep trying to connect ``client`` to ``addr`` until it succeeds.

    The TCP tests start this in a background thread *before* the server
    begins listening, so the first attempts are expected to fail with
    ConnectionFailed. A short pause follows every unsuccessful attempt so
    the server has time to come up. Without the pause, a tight loop can
    use up every attempt almost instantly, and the server's connect() is
    (presumably) left waiting for a client that never arrives.

    client -- object with a connect(addr) method and a ``connected``
              attribute (e.g. MConnectionClientTCP).
    addr   -- address passed unchanged to client.connect().
    tries  -- maximum number of attempts; None means MAXTRIES.
    delay  -- seconds to sleep after each unsuccessful attempt.

    Returns True as soon as client.connected is true, or False once all
    attempts have been used.
    """
    # Resolve the default lazily so the signature does not depend on the
    # module constant at definition time.
    if tries is None:
        tries = MAXTRIES
    for _ in range(tries):
        try:
            client.connect(addr)
        except ConnectionFailed:
            pass
        else:
            if client.connected:
                return True
        time.sleep(delay)
    return False

class TestTCPConnections(unittest.TestCase):
    """Exercise MConnectionServerTCP / MConnectionClientTCP on __addr__.

    Most tests run the client side through repeatedConnect() in a background
    thread while the test thread calls the server's connect(), so the two
    ends meet on the same address.
    """

    def setUp(self):
        self.server = MConnectionServerTCP()
        self.client = MConnectionClientTCP()

    def _connectPair(self):
        """Connect self.client (background, with retries) and self.server."""
        thread.start_new_thread(repeatedConnect, (self.client, __addr__))
        self.server.connect(__addr__)

    def testClientConnectToServer(self):
        self._connectPair()
        self.server.disconnect()

    def testClientConnectAndRead(self):
        self._connectPair()

        self.server.write("good")
        line = self.client.readline()
        # NOTE(review): the client side is expected to return "good" without
        # a trailing newline, while the server side returns 'success\n'
        # below. Confirm this asymmetry is intended by mconnection.
        self.assertEqual("good", line, "Could not read from server")
        self.client.write('success')
        line = self.server.readline()
        self.assertEqual('success\n', line, 'Could not read from client')

    def testDisconnectDisconnected(self):
        # Disconnecting a never-connected server twice must not raise.
        s = MConnectionServerTCP()
        s.disconnect()
        s.disconnect()

    def testReadline(self):
        self._connectPair()

        self.client.write('good')
        line = self.server.readline()
        self.assertEqual('good\n', line, 'Could not read first line.')

        self.server.disconnect()

    def testErrorAddressAlreadyInUse(self):
        self._connectPair()

        # Set up a second server on the same port and do not reuse the addr;
        # connecting it must fail.
        s = MConnectionServerTCP()
        try:
            self.assertRaises(ConnectionFailed, s.connect, __addr__, False)
        finally:
            # Release anything the second server allocated so it cannot
            # leak into later tests.
            s.disconnect()

    def testInvalidServerAddress(self):
        addr = 'fff.209320909xcmnm2iu3-=0-0-z.,x.,091209:2990'
        self.assertRaises(ConnectionFailed, self.server.connect, addr)

    def testConnectionRefused(self):
        # Nothing is listening on __addr__, so a single attempt must fail.
        self.assertRaises(ConnectionFailed, self.client.connect, __addr__)

    def testInvalidAddressPortPair(self):
        # Space instead of ':' between host and port.
        addr = 'localhost 8000'
        self.assertRaises(ConnectionFailed, self.server.connect, addr)

    def testServerReadError(self):
        # Phase 1: the client goes away, so a server read must fail.
        thread.start_new_thread(self.server.connect, (__addr__,))

        # Wait until the server thread has set up its socket
        # (presumably it is listening by the time _sock is set).
        while not self.server._sock:
            time.sleep(0.1)

        repeatedConnect(self.client, __addr__)

        # Wait to make _absolutely_ sure that the client has connected
        # (server.output is assumed to be set once a client is accepted).
        while not self.server.output:
            time.sleep(0.1)
        self.client.disconnect()
        self.assertRaises(ReadError, self.server.readline)

        self.server.disconnect()

        # Phase 2: the server goes away, so a client read must fail. Use the
        # retrying connect here too, so the client cannot race ahead of the
        # server's listen and leave server.connect() waiting.
        self._connectPair()

        self.server.disconnect()
        self.assertRaises(ReadError, self.client.readline)

    def tearDown(self):
        # disconnect() is safe on already-disconnected objects
        # (see testDisconnectDisconnected).
        self.server.disconnect()
        self.client.disconnect()
    
class TestSerialConnections(unittest.TestCase):
    """Test MConnectionSerial using a plain file instead of a serial device.

    On *nix systems a serial device is just a file anyway, so an ordinary
    file named TESTFN stands in for it. Server and client both open that
    same file.
    """

    def setUp(self):
        self.server = MConnectionSerial()
        self.client = MConnectionSerial()
        # Create (or truncate) an empty file to act as the device. "w" is
        # the portable mode for this; "wr+" is not a valid open() mode.
        open(TESTFN, "w").close()
        self.server.connect(TESTFN)
        self.client.connect(TESTFN)

    def testClientToServerConnect(self):
        self.client.disconnect()
        self.server.disconnect()

    def testClientWriteRead(self):
        self.client.write('success!')
        line = self.server.readline()
        self.assertEqual('success!\n', line, 'Could not read from client.')

        # Unfortunately a text file doesn't consume what we've written the
        # way a device or stream does, so we have to close the file and
        # re-open it.
        self.server.disconnect()
        self.server.connect(TESTFN)
        self.server.write('great!')
        line = self.client.readline()
        self.assertEqual('great!\n', line, 'Could not read from server.')

    def testDisconnectDisconnected(self):
        # tearDown disconnects the server a second time.
        self.server.disconnect()

    def testReadline(self):
        self.client.write('success!\nNext line.')
        self.client.disconnect()
        line = self.server.readline()
        self.assertEqual('success!\n', line, 'Could not read first line')
        line = self.server.readline()
        self.assertEqual('Next line.\n', line, 'Could not read second line.')
        # Past the end of the data, readline() returns an empty string.
        line = self.server.readline()
        self.assertEqual('', line, 'Could not read third line.')

    def testInvalidFilename(self):
        client = MConnectionSerial()
        self.assertRaises(ConnectionFailed, client.connect,
                          '/dev/pleasepleasepleasedontexit')

    def tearDown(self):
        self.server.disconnect()
        self.client.disconnect()
        os.remove(TESTFN)

class TestFIFOConnections(unittest.TestCase):
    """Exercise MConnectionServerFIFO / MConnectionClientFIFO.

    The client's connect() runs in a background thread while the test
    thread calls the server's connect(). Presumably each side's open blocks
    until the other side opens the pipe, so the two must run concurrently.
    """

    # Pipe name passed to both ends' connect().
    fifo_name = 'test_file'

    def setUp(self):
        self.server = MConnectionServerFIFO()
        self.client = MConnectionClientFIFO()

    def _connectPair(self):
        """Connect the client (background thread) and the server."""
        thread.start_new_thread(self.client.connect, (self.fifo_name,))
        self.server.connect(self.fifo_name)

    def _waitForClientInput(self):
        """Wait until the client thread has set up client.input.

        Sleeps between checks rather than spinning, so the wait does not
        burn a CPU core competing with the client thread.
        """
        while not self.client.input:
            time.sleep(0.1)

    def testConnect(self):
        self._connectPair()

    def testReadWrite(self):
        self._connectPair()

        # Server write, client read
        self.server.write('Tim The Enchanter!\n')

        # Wait for the thread to catch up
        self._waitForClientInput()
        line = self.client.readline()
        self.assertEqual('Tim The Enchanter!\n', line)

        # Client write, server read
        self.client.write('received\n')
        line = self.server.readline()
        self.assertEqual('received\n', line)

    def testMultipleDisconnect(self):
        # Disconnect before connecting; tearDown disconnects again.
        self.client.disconnect()
        self.server.disconnect()

    def testReadError(self):
        self._connectPair()
        self._waitForClientInput()

        # Reading after the peer disconnects must raise.
        self.client.disconnect()
        self.assertRaises(ReadError, self.server.readline)

        self.client.connect(self.fifo_name)

        self.server.disconnect()
        self.assertRaises(ReadError, self.client.readline)

    def testWriteError(self):
        self._connectPair()
        self._waitForClientInput()

        # Writing after the peer disconnects must raise.
        self.client.disconnect()
        self.assertRaises(WriteError, self.server.write, 'spam\n')

        self.client.connect(self.fifo_name)

        self.server.disconnect()
        self.assertRaises(WriteError, self.client.write, 'Ni!\n')

    def testInvalidPipe(self):
        self.assertRaises(ConnectionFailed, self.client.connect, 'invalid')
        # NOTE(review): the failed connect apparently leaves an 'invalid0'
        # file behind, which is removed here - confirm against mconnection.
        os.unlink('invalid0')

    def tearDown(self):
        self.client.disconnect()
        self.server.disconnect()

        
def test_main():
    """Run the TCP, serial and FIFO connection test cases via test_support."""
    cases = (TestTCPConnections, TestSerialConnections, TestFIFOConnections)
    test_support.run_unittest(*cases)


if __name__ == '__main__':
    test_main()
