import importlib

from tests import mock_importlib
from tests.ext_help import find_ext_location
from tests.py_help import TestPyPycPackages

import imp
import os
import py_compile
import sys
from test import test_support
import unittest


class LoaderBasics(unittest.TestCase):

    """Shared checks that concrete loader test cases run against a loader.

    None of the helpers is named ``test_*``; they only run when a test method
    calls them with the loader it wants to check.

    """

    def reload_test(self, loader, name):
        """Check that loading ``name`` again returns the cached module.

        The module is loaded once first if it is not already in sys.modules,
        so the checked call always happens with the module present.

        """
        # If a module already exists in sys.modules then it should be reused
        # and be re-initialized.
        if name not in sys.modules:
            loader.load_module(name)
        cached = sys.modules[name]
        loaded = loader.load_module(name)
        self.assertTrue(cached is loaded)

    def basic_test(self, loader, name, path):
        """Check a fresh load of ``name`` and the attributes of the result.

        ``path`` is the prefix that the loaded module's ``__file__`` must
        start with.

        """
        # Drop any existing entry so the module is absent before loading.
        sys.modules.pop(name, None)
        loaded = loader.load_module(name)
        # A module, after being loaded, should appear in sys.modules.
        self.assertTrue(loaded is sys.modules[name])
        self.assertEqual(loaded.__name__, name)
        self.assertTrue(loaded.__file__.startswith(path),
                        "%s does not start with %s" % (loaded.__file__, path))
        self.assertEqual(loaded.__loader__, loader)

    def ImportError_on_bad_name(self, loader, bad_name, extra_methods=()):
        """Check that the loader's methods raise ImportError for ``bad_name``.

        ``load_module``, ``is_package``, ``get_code`` and ``get_source`` are
        always checked; ``extra_methods`` is an iterable of further method
        names to check.  Its default is an immutable empty tuple so that no
        mutable object is shared between calls.

        """
        to_test = ['load_module', 'is_package', 'get_code', 'get_source']
        to_test.extend(extra_methods)
        for method_name in to_test:
            method = getattr(loader, method_name)
            self.assertRaises(ImportError, method, bad_name)


class ExtensionFileLoaderTests(LoaderBasics):

    """Tests for the extension module loader, run against ``_testcapi``."""

    testing_module = '_testcapi'
    testing_location = find_ext_location(testing_module)
    testing_path = os.path.join(*testing_location)

    def setUp(self):
        # Every test gets its own loader for the extension module.
        loader_cls = importlib._ExtensionFileLoader
        self.loader = loader_cls(self.testing_module, self.testing_path,
                                 False)

    def test_basic(self):
        # Should be able to import an extension module.
        module_name, path = self.testing_module, self.testing_path
        self.basic_test(self.loader, module_name, path)

    def test_reload(self):
        # A module that is already in sys.modules should be reused and be
        # re-initialized.
        module_name = self.testing_module
        self.reload_test(self.loader, module_name)

    def test_ImportError_on_bad_name(self):
        # ImportError should be raised when a method is called with a module
        # name it cannot handle.
        unknown_name = self.testing_module + 'asdf'
        self.ImportError_on_bad_name(self.loader, unknown_name)

    def test_is_package(self):
        # Should always be False.
        is_pkg = self.loader.is_package(self.testing_module)
        self.assertFalse(is_pkg)

    def test_get_code(self):
        # Should always be None.
        code = self.loader.get_code(self.testing_module)
        self.assertFalse(code)

    def test_get_source(self):
        # Should always be None.
        source = self.loader.get_source(self.testing_module)
        self.assertFalse(source)


class BasicPyFileLoaderTests(LoaderBasics, TestPyPycPackages):

    """Very basic tests for the source loader."""

    def setUp(self):
        TestPyPycPackages.setUp(self, faked_names=False)

    def test_basic(self):
        # Run the shared load checks for the top-level test module.
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        self.basic_test(loader, self.module_name, self.py_path)

    def test_pkg_basic(self):
        # Run the shared load checks for the package's __init__ file.
        loader = importlib._PyFileLoader(self.pkg_name, self.pkg_init_path,
                                            True)
        self.basic_test(loader, self.pkg_name, self.pkg_init_path)

    def test_reload(self):
        # Loading a module that is already in sys.modules must reuse it.
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        self.reload_test(loader, self.module_name)

    def test_ImportError_on_bad_name(self):
        # Besides the methods always checked, the bytecode-related methods
        # must also reject a module name the loader was not created for.
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        extra_methods = ['write_bytecode', 'get_bytecode', 'mod_time']
        self.ImportError_on_bad_name(loader, self.module_name + 'sdfasdf',
                                        extra_methods)

    def test_no_stale_module_on_failure(self):
        # A failure during loading should not leave a partially initialized
        # module in sys.modules.
        def fail_loading(loader, name, *args):
            # Plant an entry in sys.modules and then fail.
            sys.modules[name] = 'Should not exist'
            raise ImportError('initial failed load')
        if self.module_name in sys.modules:
            del sys.modules[self.module_name]
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        loader._handler = fail_loading
        self.assertRaises(ImportError, loader.load_module, self.module_name)
        self.assertFalse(self.module_name in sys.modules)

    def test_file_deletion_post_init(self):
        # Loading from source (with no bytecode), deleting the source (because
        # bytecode was generated), and then loading again should work.
        test_support.unlink(self.pyc_path)
        if self.module_name in sys.modules:
            del sys.modules[self.module_name]
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        found = loader.load_module(self.module_name)
        self.verify_package(found, self.module_name)
        del sys.modules[self.module_name]
        os.unlink(self.py_path)
        # Sanity check of the scenario: the first load must have written
        # bytecode for the second load to use.
        assert os.path.exists(self.pyc_path)
        found = loader.load_module(self.module_name)
        self.verify_package(found, self.module_name)

    def test_fails_with_no_files_post_init(self):
        # The loader should fail gracefully if the files it would have used to
        # load the module have been removed.
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        test_support.unlink(self.py_path)
        test_support.unlink(self.pyc_path)
        self.assertRaises(ImportError, loader.load_module, self.module_name)

    def test_reload_failure(self):
        # If a module is reloaded and something happens during loading, the
        # module should be left in place and not cleared out.
        test_support.unlink(self.pyc_path)
        # Source that raises ZeroDivisionError as soon as it is executed.
        with open(self.py_path, 'w') as source_file:
            source_file.write('x = 42/0')
        mock_module = mock_importlib.MockModule(self.module_name)
        mock_module.__file__ = 'blah'
        if hasattr(mock_module, '__loader__'):
            delattr(mock_module, '__loader__')
        sys.modules[self.module_name] = mock_module
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        self.assertRaises(ZeroDivisionError, loader.load_module,
                            self.module_name)
        # The pre-existing module must still be registered and unchanged.
        self.assertTrue(self.module_name in sys.modules)
        self.assertEqual(mock_module.__file__, 'blah')
        self.assertFalse(hasattr(mock_module, '__loader__'))


def log_call(instance, method_name):
    """Record calls to a method of ``instance`` in ``instance._log``.

    The attribute ``method_name`` on ``instance`` is replaced by a wrapper
    that appends ``method_name`` to ``instance._log`` and then forwards all
    arguments to the original attribute, returning its result.  The ``_log``
    list is created if ``instance`` does not have one yet.

    """
    wrapped = getattr(instance, method_name)
    if not hasattr(instance, '_log'):
        instance._log = []

    def record_then_call(*args, **kwargs):
        # Append before calling, so the entry is there even if the call
        # raises.
        instance._log.append(method_name)
        return wrapped(*args, **kwargs)

    setattr(instance, method_name, record_then_call)


class PyFileLoaderLoadingTests(TestPyPycPackages):

    """Check which file the source loader uses when loading.

    Where it makes sense, every scenario is exercised for both a top-level
    module and a package.

    """

    def setUp(self):
        TestPyPycPackages.setUp(self, faked_names=False)

    def test_pyc_over_py(self):
        # With good bytecode present the source should not be consulted at
        # all (top-level or package).
        scenarios = ((self.module_name, self.py_path, False),
                     (self.pkg_name, self.pkg_init_path, True))
        for mod_name, source_path, is_pkg in scenarios:
            sys.modules.pop(mod_name, None)
            py_compile.compile(source_path)
            loader = importlib._PyFileLoader(mod_name, source_path, is_pkg)
            log_call(loader, 'get_source')
            module = loader.load_module(mod_name)
            self.assertFalse('get_source' in loader._log)
            self.verify_package(module, mod_name)

    def test_only_good_pyc(self):
        # Loading must work when nothing but bytecode is available (top-level
        # or package).
        scenarios = (
            (self.module_name, self.py_path, self.pyc_path, False),
            (self.pkg_name, self.pkg_init_path, self.pkg_init_pyc_path, True),
        )
        for mod_name, source_path, bytecode_path, is_pkg in scenarios:
            sys.modules.pop(mod_name, None)
            py_compile.compile(source_path)
            os.unlink(source_path)
            loader = importlib._PyFileLoader(mod_name, bytecode_path, is_pkg)
            for watched in ('mod_time', 'get_source'):
                log_call(loader, watched)
            module = loader.load_module(mod_name)
            self.assertFalse('mod_time' in loader._log)
            self.assertFalse('get_source' in loader._log)
            self.verify_package(module, mod_name)

    def test_only_py(self):
        # Having only source should be fine (top-level or package).
        scenarios = (
            (self.module_name, self.py_path, self.pyc_path, False),
            (self.pkg_name, self.pkg_init_path, self.pkg_init_pyc_path, True),
        )
        for mod_name, source_path, bytecode_path, is_pkg in scenarios:
            sys.modules.pop(mod_name, None)
            test_support.unlink(bytecode_path)
            loader = importlib._PyFileLoader(mod_name, source_path, is_pkg)
            for watched in ('get_bytecode', 'write_bytecode'):
                log_call(loader, watched)
            module = loader.load_module(mod_name)
            self.assertFalse('get_bytecode' in loader._log)
            self.assertTrue(os.path.exists(bytecode_path))
            self.verify_package(module, mod_name)
            # Now load again from the bytecode that was just generated, with
            # the source gone, to make sure that bytecode is good.
            del sys.modules[mod_name]
            os.unlink(source_path)
            assert os.path.exists(bytecode_path)
            loader = importlib._PyFileLoader(mod_name, bytecode_path, is_pkg)
            for watched in ('mod_time', 'get_bytecode'):
                log_call(loader, watched)
            module = loader.load_module(mod_name)
            self.assertFalse('mod_time' in loader._log)
            self.assertTrue('get_bytecode' in loader._log)
            self.verify_package(module, mod_name)

    def test_stale_pyc(self):
        # Bytecode that is stale compared to its source must be regenerated.
        mod_name = self.module_name
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        log_call(loader, 'write_bytecode')
        with open(self.pyc_path, 'rb') as bytecode_file:
            contents = bytecode_file.read()
        recorded = importlib._r_long(contents[4:8])
        # Rewrite the file with bytes 4-8 holding a value one less than the
        # timestamp that was recorded there.
        with open(self.pyc_path, 'wb') as bytecode_file:
            bytecode_file.write(contents[:4])
            bytecode_file.write(importlib._w_long(recorded - 1))
            bytecode_file.write(contents[8:])
        module = loader.load_module(mod_name)
        self.assertTrue('write_bytecode' in loader._log)
        self.verify_package(module, mod_name)
        source_mtime = os.stat(self.py_path).st_mtime
        bytecode_mtime = os.stat(self.pyc_path).st_mtime
        self.assertTrue(source_mtime <= bytecode_mtime)
        # The rewritten bytecode must be loadable without touching source.
        del sys.modules[mod_name]
        loader = importlib._PyFileLoader(mod_name, self.pyc_path, False)
        for watched in ('get_bytecode', 'get_source'):
            log_call(loader, watched)
        module = loader.load_module(mod_name)
        self.assertTrue('get_bytecode' in loader._log)
        self.assertFalse('get_source' in loader._log)
        self.verify_package(module, mod_name)

    def test_bad_magic(self):
        # A bad magic cookie in the bytecode means regenerating the bytecode
        # when source exists, and raising an exception when it does not.
        mod_name = self.module_name
        bogus_magic = '0' * 4

        def change_magic():
            # Overwrite the first four bytes of the bytecode file.
            with open(self.pyc_path, 'rb') as bytecode_file:
                contents = bytecode_file.read()
            assert imp.get_magic() != bogus_magic
            with open(self.pyc_path, 'wb') as bytecode_file:
                bytecode_file.write(bogus_magic)
                bytecode_file.write(contents[4:])

        # With source.
        change_magic()
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        log_call(loader, 'write_bytecode')
        module = loader.load_module(mod_name)
        self.assertTrue('write_bytecode' in loader._log)
        self.verify_package(module, mod_name)
        source_mtime = os.stat(self.py_path).st_mtime
        bytecode_mtime = os.stat(self.pyc_path).st_mtime
        self.assertTrue(source_mtime <= bytecode_mtime)
        del sys.modules[mod_name]
        loader = importlib._PyFileLoader(mod_name, self.pyc_path, False)
        for watched in ('get_bytecode', 'get_source'):
            log_call(loader, watched)
        module = loader.load_module(mod_name)
        self.assertTrue('get_bytecode' in loader._log)
        self.assertFalse('get_source' in loader._log)
        self.verify_package(module, mod_name)
        # Without source.
        del sys.modules[mod_name]
        os.unlink(self.py_path)
        change_magic()
        loader = importlib._PyFileLoader(mod_name, self.pyc_path, False)
        self.assertRaises(ImportError, loader.load_module, mod_name)

    def test_malformed_bytecode(self):
        # Invalid bytecode triggers an exception, source or not.
        source_mtime = int(os.stat(self.py_path).st_mtime)
        # Write a bytecode file that holds only the magic number and the
        # source's modification time, with nothing after them.
        with open(self.pyc_path, 'wb') as bytecode_file:
            bytecode_file.write(imp.get_magic())
            bytecode_file.write(importlib._w_long(source_mtime))
        loader = importlib._PyFileLoader(self.module_name, self.py_path, False)
        self.assertRaises(Exception, loader.load_module, self.module_name)


class PEP302PyFileInterface(TestPyPycPackages):

    """Tests for the optional loader extensions described by PEP 302."""

    def setUp(self):
        TestPyPycPackages.setUp(self, faked_names=False)

    def test_get_source(self):
        # Return the source when available, None if there is at least bytecode,
        # and raise ImportError if there is no chance of loading the module.
        mod_name = self.module_name
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        self.assertEqual(loader.get_source(mod_name), self.source)
        os.unlink(self.py_path)
        self.assertTrue(loader.get_source(mod_name) is None)
        os.unlink(self.pyc_path)
        self.assertRaises(ImportError, loader.get_source, mod_name)

    def test_get_data(self):
        # get_data() must hand back exactly the bytes stored in the file.
        loader = importlib._PyFileLoader(self.module_name, self.pyc_path,
                                         False)
        returned = loader.get_data(self.pyc_path)
        with open(self.pyc_path, 'rb') as bytecode_file:
            on_disk = bytecode_file.read()
        self.assertEqual(returned, on_disk)

    def test_is_package(self):
        # False for the top-level module, True for the package.
        module_loader = importlib._PyFileLoader(self.module_name,
                                                self.py_path, False)
        self.assertFalse(module_loader.is_package(self.module_name))
        pkg_loader = importlib._PyFileLoader(self.pkg_name,
                                             self.pkg_init_path, True)
        self.assertTrue(pkg_loader.is_package(self.pkg_name))

    def get_code_test(self, loader):
        # Execute the code object from the loader in an empty namespace and
        # check that the expected attribute is defined with the right value.
        expected_name, expected_value = self.test_attr
        code_object = loader.get_code(loader._name)
        namespace = {}
        exec(code_object, namespace)
        self.assertTrue(expected_name in namespace)
        self.assertEqual(namespace[expected_name], expected_value)

    def test_get_code(self):
        mod_name = self.module_name
        # Source and bytecode.
        assert os.path.exists(self.py_path) and os.path.exists(self.pyc_path)
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        self.get_code_test(loader)
        # Source only.
        os.unlink(self.pyc_path)
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        self.get_code_test(loader)
        # Bytecode only.
        py_compile.compile(self.py_path, doraise=True)
        os.unlink(self.py_path)
        loader = importlib._PyFileLoader(mod_name, self.pyc_path, False)
        self.get_code_test(loader)
        # Bad magic number in bytecode.
        with open(self.pyc_path, 'rb') as bytecode_file:
            contents = bytecode_file.read()
        filler = '0' * 4
        with open(self.pyc_path, 'wb') as bytecode_file:
            # The first eight bytes are replaced by two four-byte fillers.
            bytecode_file.write(filler)
            bytecode_file.write(filler)
            bytecode_file.write(contents[8:])
        loader = importlib._PyFileLoader(mod_name, self.pyc_path, False)
        self.assertRaises(ImportError, loader.get_code, mod_name)
        # Nothing available.
        os.unlink(self.pyc_path)
        assert not os.path.exists(self.py_path)
        assert not os.path.exists(self.pyc_path)
        loader = importlib._PyFileLoader(mod_name, self.py_path, False)
        self.assertRaises(ImportError, loader.get_code, mod_name)


def test_main():
    """Run all of the loader test cases defined in this module."""
    test_cases = (
        ExtensionFileLoaderTests,
        BasicPyFileLoaderTests,
        PyFileLoaderLoadingTests,
        PEP302PyFileInterface,
    )
    test_support.run_unittest(*test_cases)


# Run the loader tests when this file is executed directly as a script.
if __name__ == '__main__':
    test_main()
