import contextlib
import importlib
import os
import py_compile
import shutil
import sys
import time
import unittest
import zipfile
from test import test_support

from zipimport_ import zipimport
#import zipimport

# Source text written into every module of the temporary archive; executing
# it binds the name ``attr`` to None.
example_code = 'attr = None'

# Extension-less, cwd-relative paths of every module temp_zipfile() creates:
# a top-level module, a package with a submodule, and a subpackage (itself
# holding a submodule) nested inside that package.
created_paths = set(os.path.join(*parts) for parts in (
    ('_top_level',),
    ('_pkg', '__init__'),
    ('_pkg', 'submodule'),
    ('_pkg', '_subpkg', '__init__'),
    ('_pkg', '_subpkg', 'submodule'),
))

@contextlib.contextmanager
def temp_zipfile(source=True, bytecode=True):
    """Create a temporary zip file for testing.

    Every module listed in ``created_paths`` is written to disk (relative to
    the current directory) with ``example_code`` as its contents and added to
    the archive as source and/or bytecode.  All files and directories created
    here, and the archive itself, are removed on exit.

    Clears zipimport._zip_directory_cache both before the archive is built
    and after it has been removed, so no cached directory listing outlives
    the archive it describes.

    :param source: add the ``.py`` files to the archive.
    :param bytecode: compile each module with py_compile and add the
        resulting ``.pyc`` (``.pyo`` when running without __debug__) file.
    :yields: the absolute path of the finished archive.

    """
    zipimport._zip_directory_cache.clear()
    zip_path = test_support.TESTFN + '.zip'
    # Mirrors the suffix py_compile chooses for its output file.
    bytecode_suffix = 'c' if __debug__ else 'o'
    zip_file = zipfile.ZipFile(zip_path, 'w')
    try:
        for path in created_paths:
            if os.sep in path:
                directory = os.path.split(path)[0]
                if not os.path.exists(directory):
                    os.makedirs(directory)
            code_path = path + '.py'
            with open(code_path, 'w') as temp_file:
                temp_file.write(example_code)
            if source:
                zip_file.write(code_path)
            if bytecode:
                py_compile.compile(code_path, doraise=True)
                zip_file.write(code_path + bytecode_suffix)
        # Close (and so finish writing) the archive before handing it out;
        # closing it a second time in the finally clause is a no-op.
        zip_file.close()
        yield os.path.abspath(zip_path)
    finally:
        zip_file.close()
        for path in created_paths:
            if os.sep in path:
                # Removing the directory also removes the package's source
                # and bytecode files.
                directory = os.path.split(path)[0]
                if os.path.exists(directory):
                    shutil.rmtree(directory)
            else:
                for suffix in ('.py', '.py' + bytecode_suffix):
                    test_support.unlink(path + suffix)
        test_support.unlink(zip_path)
        # The archive is gone; do not leave a stale listing cached for it.
        zipimport._zip_directory_cache.clear()


class ZipImportErrorTests(unittest.TestCase):

    """Tests for the ZipImportError exception type."""

    def test_inheritance(self):
        # Code that catches ImportError must also catch ZipImportError.
        is_import_error = issubclass(zipimport.ZipImportError, ImportError)
        self.assertTrue(is_import_error)


class ZipImportCreation(unittest.TestCase):

    """Test the creation of a zipimport.zipimporter instance."""

    def test_nonzip(self):
        # ZipImportError should be raised if a non-zip file is specified.
        with open(test_support.TESTFN, 'w') as test_file:
            test_file.write("# Test file for zipimport.")
        try:
            self.assertRaises(zipimport.ZipImportError,
                              zipimport.zipimporter, test_support.TESTFN)
        finally:
            test_support.unlink(test_support.TESTFN)

    def test_root(self):
        # The filesystem root is not a zip archive, so it must be rejected.
        self.assertRaises(zipimport.ZipImportError, zipimport.zipimporter,
                          os.sep)

    def test_direct_path(self):
        # A zipfile should return an instance of zipimporter.
        with temp_zipfile() as zip_path:
            zip_importer = zipimport.zipimporter(zip_path)
            self.assertTrue(isinstance(zip_importer, zipimport.zipimporter))
            self.assertEqual(zip_importer.archive, zip_path)
            self.assertEqual(zip_importer.prefix, '')
            self.assertTrue(zip_path in zipimport._zip_directory_cache)

    def test_pkg_path(self):
        # Thanks to __path__, need to be able to work off of a path with a zip
        # file at the front and a path for the rest.
        with temp_zipfile() as zip_path:
            prefix = '_pkg'
            path = os.path.join(zip_path, prefix)
            zip_importer = zipimport.zipimporter(path)
            self.assertTrue(isinstance(zip_importer, zipimport.zipimporter))
            self.assertEqual(zip_importer.archive, zip_path)
            self.assertEqual(zip_importer.prefix, prefix)
            self.assertTrue(zip_path in zipimport._zip_directory_cache)

    def test_zip_directory_cache(self):
        # Test that _zip_directory_cache is set properly.
        # Using a package entry to test using a hard example.
        with temp_zipfile(bytecode=False) as zip_path:
            # Only the construction matters: it must populate the cache
            # entry for the archive checked below.
            zipimport.zipimporter(os.path.join(zip_path, '_pkg'))
            self.assertTrue(zip_path in zipimport._zip_directory_cache)
            file_set = set(zipimport._zip_directory_cache[zip_path].iterkeys())
            compare_set = set(path + '.py' for path in created_paths)
            self.assertEqual(file_set, compare_set)


class FindModule(unittest.TestCase):

    """Test the finding of modules."""

    def look_for(self, module_name):
        """Yield ``(importer, importer.find_module(module_name))`` pairs.

        One pair is produced for each archive layout under test: source plus
        bytecode, source only, and bytecode only.  For a dotted name the
        importer is rooted at the parent package's directory inside the
        archive (``a.b.c`` -> ``<archive>/a/b``), the same kind of path a
        package's ``__path__`` entry would hold.

        The archive behind each importer is deleted as soon as the caller
        advances the generator.

        """
        parent_parts = module_name.split('.')[:-1]
        for args in [[True, True], [True, False], [False, True]]:
            with temp_zipfile(*args) as zip_path:
                # With no parent parts the join leaves zip_path unchanged,
                # which is what a top-level name needs.
                search_path = os.path.join(zip_path, *parent_parts)
                importer = zipimport.zipimporter(search_path)
                yield (importer, importer.find_module(module_name))

    def test_top_level(self):
        # Should be able to find a top-level module if there is source,
        # bytecode, or both.
        for importer, result in self.look_for('_top_level'):
            self.assertTrue(importer is result)
        for importer, result in self.look_for('_bogus'):
            self.assertTrue(result is None)

    def test_pkg(self):
        # Finding a package should work.
        for importer, result in self.look_for('_pkg'):
            self.assertTrue(importer is result)

    def test_submodule(self):
        # A submodule in a package should work.
        for importer, result in self.look_for('_pkg.submodule'):
            self.assertTrue(importer is result)
        for importer, result in self.look_for('_bogus.submodule'):
            self.assertTrue(result is None)
        for importer, result in self.look_for('_pkg.bogus'):
            self.assertTrue(result is None)


class GetData(unittest.TestCase):

    """Test zipimporter.get_data()."""

    def test_text(self):
        # A path relative to the archive root returns the stored contents.
        with temp_zipfile(bytecode=False) as archive:
            data = zipimport.zipimporter(archive).get_data('_top_level.py')
            self.assertEqual(data, example_code)

    def test_absolute_path(self):
        # A path that starts with the archive's own location works as well.
        with temp_zipfile(bytecode=False) as archive:
            loader = zipimport.zipimporter(archive)
            full_path = os.path.join(os.path.abspath(archive),
                                     '_top_level.py')
            self.assertEqual(loader.get_data(full_path), example_code)

    # XXX Test that file reading is done in binary mode.


class IsPackage(unittest.TestCase):

    """Test zipimporter.is_package()."""

    def test_pkg(self):
        # Packages report True and plain modules report False, whether the
        # archive holds source, bytecode, or both.
        layouts = ([True, True], [True, False], [False, True])
        for source, bytecode in layouts:
            with temp_zipfile(source, bytecode) as archive:
                root_importer = zipimport.zipimporter(archive)
                root_importer.find_module('_pkg')
                self.assertTrue(root_importer.is_package('_pkg'))
                root_importer.find_module('_top_level')
                self.assertFalse(root_importer.is_package('_top_level'))
                pkg_importer = zipimport.zipimporter(
                    os.path.join(archive, '_pkg'))
                pkg_importer.find_module('_pkg.submodule')
                self.assertFalse(pkg_importer.is_package('_pkg.submodule'))


class GetSource(unittest.TestCase):

    """Test zipimporter.get_source()."""

    def _check_source(self, module_name, expected, **layout):
        # Build an archive with the requested layout (temp_zipfile keyword
        # arguments), locate module_name at the archive root and compare
        # what get_source() returns with the expectation.
        with temp_zipfile(**layout) as archive:
            loader = zipimport.zipimporter(archive)
            loader.find_module(module_name)
            self.assertEqual(loader.get_source(module_name), expected)

    def test_get_top_level_source(self):
        # Source-only archive: the module's text comes back verbatim.
        self._check_source('_top_level', example_code, bytecode=False)

    def test_no_top_level_source(self):
        # Bytecode-only archive: there is no source, so None is returned.
        self._check_source('_top_level', None, source=False)

    def test_pkg_source(self):
        # Source-only archive: a package's source text is returned.
        self._check_source('_pkg', example_code, bytecode=False)

    def test_no_pkg_source(self):
        # Bytecode-only archive: a package has no source either.
        self._check_source('_pkg', None, source=False)


class GetCode(unittest.TestCase):

    """Test zipimporter.get_code()."""

    def test_mod_time(self):
        with temp_zipfile() as zip_path:
            actual_mtime = int(os.stat('_top_level.py').st_mtime)
            importer = zipimport.zipimporter(zip_path)
            zip_mtime = importer.mod_time('_top_level')
            # Zip entries store MS-DOS timestamps, which only have two-second
            # resolution.  Assuming mod_time() reports the archived
            # timestamp, it can trail the file's whole-second mtime by one.
            self.assertTrue(actual_mtime - zip_mtime in (0, 1))

    def verify_code(self, archive_path, module):
        """Execute the code object for ``module`` and check its namespace.

        ``archive_path`` is handed to zipimporter.  Running the code must
        bind ``attr`` to None, exactly as ``example_code`` does.

        """
        importer = zipimport.zipimporter(archive_path)
        code_obj = importer.get_code(module)
        ns = {}
        # The tuple form is the exec statement on Python 2 and the exec()
        # builtin on Python 3; both run code_obj with ns as its globals.
        exec(code_obj, ns)
        self.assertTrue('attr' in ns)
        self.assertEqual(ns['attr'], None)

    def test_top_level(self):
        # Source plus bytecode, source only, bytecode only.
        for args in [[True, True], [True, False], [False, True]]:
            with temp_zipfile(*args) as zip_path:
                self.verify_code(zip_path, '_top_level')

    def test_pkg(self):
        # Same layouts as test_top_level, for a package's __init__.
        for args in [[True, True], [True, False], [False, True]]:
            with temp_zipfile(*args) as zip_path:
                self.verify_code(zip_path, '_pkg')


class LoadModule(unittest.TestCase):

    """Test zipimporter.load_module()."""

    # Extension of the bytecode files temp_zipfile() puts in the archive:
    # '.pyo' instead of '.pyc' when running without __debug__.
    bytecode_ext = '.py' + ('c' if __debug__ else 'o')

    def tearDown(self):
        # A PEP 302 loader registers what it loads in sys.modules; drop the
        # test modules so they cannot leak into later tests.
        for name in list(sys.modules):
            if name in ('_top_level', '_pkg') or name.startswith('_pkg.'):
                del sys.modules[name]

    def test_top_level(self):
        with temp_zipfile() as zip_path:
            importer = zipimport.zipimporter(zip_path)
            module = importer.load_module('_top_level')
        # These checks only inspect the module object, not the archive.
        self.assertEqual(module.__name__, '_top_level')
        self.assertEqual(module.__loader__, importer)
        self.assertEqual(module.__file__,
                         os.path.join(zip_path,
                                      '_top_level' + self.bytecode_ext))
        self.assertTrue(hasattr(module, 'attr'))
        self.assertFalse(hasattr(module, '__path__'))
        self.assertEqual(getattr(module, 'attr'), None)

    def test_pkg(self):
        with temp_zipfile() as zip_path:
            importer = zipimport.zipimporter(zip_path)
            module = importer.load_module('_pkg')
            self.assertEqual(module.__name__, '_pkg')
            self.assertEqual(module.__loader__, importer)
            self.assertEqual(module.__file__,
                             os.path.join(zip_path, '_pkg',
                                          '__init__' + self.bytecode_ext))
            self.assertEqual(module.__path__,
                             [os.path.join(zip_path, '_pkg')])
            self.assertTrue(hasattr(module, 'attr'))
            self.assertEqual(getattr(module, 'attr'), None)
            # Subpackages are loaded through an importer rooted at the
            # parent package's directory inside the archive.
            importer = zipimport.zipimporter(os.path.join(zip_path, '_pkg'))
            module = importer.load_module('_pkg._subpkg')
        self.assertEqual(module.__name__, '_pkg._subpkg')
        self.assertEqual(module.__path__,
                         [os.path.join(zip_path, '_pkg', '_subpkg')])


def test_main():
    """Run every TestCase class in this module via test_support."""
    test_support.run_unittest(
        ZipImportErrorTests,
        ZipImportCreation,
        FindModule,
        GetData,
        IsPackage,
        GetSource,
        GetCode,
        LoadModule,
    )


if __name__ == '__main__':
    test_main()
