from __future__ import with_statement

import marshal
import os.path
import py_compile
import shutil
import sys
import tempfile
import unittest


class TestPyPycFiles(unittest.TestCase):

    """Base class to help in generating a fresh source and bytecode file.

    Structure of created files:
    * directory/
        + py_path [module_name + py_ext]
        + pyc_path [module_name + pyc_ext]

    """

    def setUp(self, faked_names=True):
        """Generate a source file and its bytecode file to test with.

        If faked_names is true then all names are non-standard compared to
        normal Python files (module name '<test_module>' with the '.source'
        and '.bytecode' extensions); otherwise '_test_module' with '.py' and
        '.pyc' ('.pyo' when running with -O) is used.

        Any module already registered under the chosen name is dropped from
        sys.modules, and the containing directory is put at the front of
        sys.path (tearDown removes it again).

        """
        # Set needed info for file paths.
        if faked_names:
            self.module_name = '<test_module>'
            self.py_ext = '.source'
            self.pyc_ext = '.bytecode'
        else:
            self.module_name = '_test_module'
            self.py_ext = '.py'
            self.pyc_ext = '.pyc' if __debug__ else '.pyo'
        # Make sure a previously imported copy cannot satisfy an import.
        sys.modules.pop(self.module_name, None)
        self.directory = tempfile.gettempdir()
        self.test_attr = ('test_attr', None)
        self.py_path = os.path.join(self.directory,
                                    self.module_name + self.py_ext)
        self.pyc_path = os.path.join(self.directory,
                                     self.module_name + self.pyc_ext)
        # Create source file.
        self.source = '%s = %r' % self.test_attr
        with open(self.py_path, 'w') as py_file:
            py_file.write(self.source)
        # Create bytecode file.  The third positional parameter of
        # py_compile.compile() is 'dfile' (the filename recorded in the code
        # object), not a compile mode, so it is left at its default: the
        # source path.  doraise makes a compilation failure abort setUp
        # instead of merely being printed with no bytecode file written.
        py_compile.compile(self.py_path, self.pyc_path, doraise=True)
        # Also keep the bare marshalled code object (no bytecode-file header).
        code = compile(self.source, self.pyc_path, 'exec')
        self.bytecode = marshal.dumps(code)
        sys.path.insert(0, self.directory)

    def tearDown(self):
        """Remove the sys.path entry and any temporary files still present."""
        if self.directory in sys.path:
            sys.path.remove(self.directory)
        for path in (self.py_path, self.pyc_path):
            if os.path.exists(path):
                os.remove(path)

    def verify_module(self, module, file_path=None):
        """Verify that the module is correct.

        The module must define the test attribute with its expected value.
        When file_path is given, the module's __name__ must also equal
        module_name and its __file__ must equal file_path.

        """
        if file_path:
            self.assertEqual(module.__name__, self.module_name)
            self.assertEqual(module.__file__, file_path)
        attr_name, attr_value = self.test_attr
        self.assertTrue(hasattr(module, attr_name))
        self.assertEqual(getattr(module, attr_name), attr_value)


class TestPyPycPackages(TestPyPycFiles):

    """Create a testing package.

    Structure of created files (on top of files created in superclasses):
    * self.directory/
        + top_level_module_path [top_level_module_name + py_ext]
        + pkg_path/ [pkg_name]
            - pkg_init_path [ '__init__' + py_ext]
            - pkg_module_path [module_name + py_ext]
            - sub_pkg_path/ [sub_pkg_name]
                * sub_pkg_init_path ['__init__' + py_ext]
                * sub_pkg_module_path [module_name + py_ext]

    """

    def setUp(self, faked_names=True):
        """Create the files of the superclass plus the package tree above.

        Every created source file contains the same source as the
        superclass's module, so each one defines the test attribute.

        """
        TestPyPycFiles.setUp(self, faked_names)
        self.top_level_module_name = 'top_level_' + self.module_name
        self.top_level_module_path = os.path.join(self.directory,
                                        self.top_level_module_name+self.py_ext)
        with open(self.top_level_module_path, 'w') as top_level_file:
            top_level_file.write(self.source)
        if faked_names:
            self.pkg_name = '<test_pkg>'
        else:
            self.pkg_name = '_test_pkg'
        self.pkg_path = os.path.join(self.directory, self.pkg_name)
        try:
            os.mkdir(self.pkg_path)
        except OSError:
            # Presumably a package directory left behind by an earlier run
            # that did not clean up.  Remove only that stale tree: calling
            # self.tearDown() here would also delete the files created above
            # and drop self.directory from sys.path.
            shutil.rmtree(self.pkg_path)
            os.mkdir(self.pkg_path)
        self.pkg_init_path = os.path.join(self.pkg_path,
                                            '__init__'+self.py_ext)
        self.pkg_init_pyc_path = os.path.join(self.pkg_path,
                                                '__init__'+self.pyc_ext)
        with open(self.pkg_init_path, 'w') as pkg_file:
            pkg_file.write(self.source)
        self.pkg_module_name = '.'.join([self.pkg_name, self.module_name])
        self.pkg_module_path = os.path.join(self.pkg_path,
                                        self.module_name+self.py_ext)
        with open(self.pkg_module_path, 'w') as module_file:
            module_file.write(self.source)
        self.sub_pkg_tail_name = 'sub_pkg'
        self.sub_pkg_name = '.'.join([self.pkg_name, self.sub_pkg_tail_name])
        self.sub_pkg_path = os.path.join(self.pkg_path, self.sub_pkg_tail_name)
        os.mkdir(self.sub_pkg_path)
        self.sub_pkg_init_path = os.path.join(self.sub_pkg_path,
                                                '__init__'+self.py_ext)
        with open(self.sub_pkg_init_path, 'w') as subpkg_file:
            subpkg_file.write(self.source)
        self.sub_pkg_module_name = '.'.join([self.sub_pkg_name,
                                            self.module_name])
        self.sub_pkg_module_path = os.path.join(self.sub_pkg_path,
                                                self.module_name+self.py_ext)
        with open(self.sub_pkg_module_path, 'w') as submodule_file:
            submodule_file.write(self.source)

    def tearDown(self):
        """Remove everything setUp created; files already gone are skipped."""
        TestPyPycFiles.tearDown(self)
        top_level_base = os.path.splitext(self.top_level_module_path)[0]
        for path in (self.top_level_module_path,
                     top_level_base + self.pyc_ext):
            if os.path.exists(path):
                os.remove(path)
        if os.path.isdir(self.pkg_path):
            shutil.rmtree(self.pkg_path)

    def verify_package(self, module, actual_name=None):
        """Verify a module object obtained from the test package hierarchy.

        module must define the test attribute.  If module is the top-level
        package, its __file__/__path__ are checked. When actual_name (a
        dotted name) mentions the package's module or sub-package, those
        must be reachable as attributes of module and are verified as well.
        If module is the package's module, its __file__ is checked.

        """
        self.assertTrue(hasattr(module, self.test_attr[0]))
        self.assertEqual(getattr(module, self.test_attr[0]),
                            self.test_attr[1])
        if module.__name__ == self.pkg_name:
            # Either the source or the bytecode file may back __init__.
            self.assertTrue(module.__file__ in
                            [self.pkg_init_path, self.pkg_init_pyc_path])
            self.assertEqual(module.__path__, [self.pkg_path])
            # Module in top-level package.
            if actual_name and self.pkg_module_name in actual_name:
                self.assertTrue(hasattr(module, self.module_name))
                sub_module = getattr(module, self.module_name)
                self.assertEqual(sub_module.__name__, self.pkg_module_name)
                self.assertEqual(sub_module.__file__, self.pkg_module_path)
                self.verify_module(sub_module)
            # Package within top-level package.
            if actual_name and self.sub_pkg_name in actual_name:
                self.assertTrue(hasattr(module, self.sub_pkg_tail_name))
                sub_pkg = getattr(module, self.sub_pkg_tail_name)
                self.assertEqual(sub_pkg.__name__, self.sub_pkg_name)
                self.assertEqual(sub_pkg.__file__, self.sub_pkg_init_path)
                self.assertEqual(sub_pkg.__path__, [self.sub_pkg_path])
                self.verify_module(sub_pkg)
                if actual_name == self.sub_pkg_module_name:
                    self.assertTrue(hasattr(sub_pkg, self.module_name))
                    sub_module = getattr(sub_pkg, self.module_name)
                    self.assertEqual(sub_module.__name__,
                                        self.sub_pkg_module_name)
                    self.assertEqual(sub_module.__file__,
                                        self.sub_pkg_module_path)
                    self.verify_module(sub_module)
        if module.__name__ == self.pkg_module_name:
            self.assertEqual(module.__file__, self.pkg_module_path)


