#!/usr/bin/env python3.0

"""Unit tests for io.py."""

import os
import shutil
import tempfile
import unittest

import io # Local


class IOTestCase(unittest.TestCase):

    def setUp(self):
        self.tfn = tempfile.mktemp()

    def tearDown(self):
        if os.path.exists(self.tfn):
            os.remove(self.tfn)

    def test_unbuffered(self):
        sample1 = bytes("hello ")
        sample2 = bytes("world\n")
        f = io.open(self.tfn, "wb", 0)
        try:
            f.write(sample1)
            f.write(sample2)
        finally:
            f.close()

        f = open(self.tfn) # Classic open!
        try:
            data = f.read()
        finally:
            f.close()
        self.assertEquals(bytes(data), sample1+sample2)

        f = io.open(self.tfn, "rb", 0)
        try:
            data = f.read(1)
            self.assertEquals(data, sample1[:1])
            self.assertEquals(f.tell(), 1)
            f.seek(0)
            self.assertEquals(f.tell(), 0)
            data = f.read(len(sample1))
            self.assertEquals(data, sample1)
            data += f.read(100)
            self.assertEquals(data, sample1+sample2)
        finally:
            f.close()

    def test_buffered_read(self):
        sample1 = bytes("hello ")
        sample2 = bytes("world\n")
        for bufsize in 1, 2, 3, 4, 5, 6, 7, 8, 16, 8*1024:
            f = io.open(self.tfn, "wb", bufsize)
            try:
                f.write(sample1)
                f.write(sample2)
            finally:
                f.close()

            f = open(self.tfn) # Classic open!
            try:
                data = f.read()
            finally:
                f.close()
            self.assertEquals(bytes(data), sample1+sample2,
                              "%r != %r, bufsize=%s" % (bytes(data),
                                                        sample1+sample2,
                                                        bufsize))

            f = io.open(self.tfn, "rb", bufsize)
            try:
                data = f.read(1)
                self.assertEquals(data, sample1[:1])
                self.assertEquals(f.tell(), 1)
                f.seek(0)
                self.assertEquals(f.tell(), 0)
                data = f.read(len(sample1))
                self.assertEquals(data, sample1)
                data += f.read(100)
                self.assertEquals(data, sample1+sample2)
            finally:
                f.close()

            os.remove(self.tfn)
        

if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main(module="__main__")
