from __future__ import with_statement

import importlib

from tests import mock_importlib
from tests.ext_help import find_ext_location
from tests.py_help import TestPyPycPackages

import imp
import os
import os.path
import py_compile
import sys
import tempfile
from test import test_support
import unittest
import warnings


class PyFileImporterTests(TestPyPycPackages):

    """Test importlib.PyFileImporter (and thus also FileImporter).

    No need to check for searching any deeper than a package (e.g., don't need
    to look for a sub-module or sub-package) since the import machinery will
    always have an importer for the containing package directory.

    setUp() replaces the importer's ``_loader`` with a function that returns
    the path it is given, so ``find_module()`` returns the path of the file it
    would load (or None when nothing matches) instead of a loader object.

    Preconditions are checked with TestCase assertions rather than bare
    ``assert`` statements so that they are still enforced under ``python -O``,
    which these tests explicitly account for (see ``_bytecode_suffix``).

    """

    # Platforms on which these tests expect PYTHONCASEOK to allow a
    # case-insensitive match; on all other platforms the PYTHONCASEOK half of
    # the case-sensitivity tests is skipped.
    _case_insensitive_platforms = ('win32', 'mac', 'darwin', 'cygwin',
                                   'os2emx', 'riscos')

    # Character appended to a source path to name the bytecode file the tests
    # expect py_compile.compile() to write: 'c' normally, 'o' under -O.
    _bytecode_suffix = 'c' if __debug__ else 'o'

    @staticmethod
    def _mixed_case(name):
        """Return *name* with its first half upper-cased, the rest lower-cased.

        Used to build a wrongly-cased spelling of an existing module or
        package name for the case-sensitivity tests.
        """
        half = len(name) // 2
        return name[:half].upper() + name[half:].lower()

    def setUp(self):
        TestPyPycPackages.setUp(self, faked_names=False)
        self.importer = importlib.PyFileImporter(self.directory)
        # Make find_module() hand back the found path rather than a loader.
        self.importer._loader = lambda name, path, is_pkg: path

    def tearDown(self):
        TestPyPycPackages.tearDown(self)

    def test_py_top_level(self):
        # A top-level source module should work.
        test_support.unlink(self.pyc_path)
        found = self.importer.find_module(self.module_name)
        self.assertEqual(found, self.py_path)

    def test_failure(self):
        # With neither the source nor the bytecode file present, nothing is
        # found.
        os.unlink(self.py_path)
        os.unlink(self.pyc_path)
        found = self.importer.find_module(self.module_name)
        self.assert_(found is None)

    def test_pyc_top_level(self):
        # A top-level bytecode module should work.
        test_support.unlink(self.py_path)
        found = self.importer.find_module(self.module_name)
        self.assertEqual(found, self.pyc_path)

    def test_py_package(self):
        # A Python source package should be found.
        # TestPyPycPackages, by default, does not compile the __init__ file.
        self.assert_(self.pkg_init_path[-1] not in ('c', 'o'),
                     "package __init__ path should name a source file")
        self.assert_(not os.path.exists(self.pkg_init_path + 'c'))
        self.assert_(not os.path.exists(self.pkg_init_path + 'o'))
        found = self.importer.find_module(self.pkg_name)
        self.assertEqual(found, self.pkg_init_path)

    def test_pyc_package(self):
        # A bytecode package should be found.
        py_compile.compile(self.pkg_init_path)
        test_support.unlink(self.pkg_init_path)
        expected = self.pkg_init_path + self._bytecode_suffix
        found = self.importer.find_module(self.pkg_name)
        self.assertEqual(found, expected)

    def test_file_type_order(self):
        # The order of the search should be preserved with source files being
        # found first.
        self.assert_(os.path.exists(self.py_path))
        self.assert_(self.py_ext in self.importer._suffixes)
        self.assert_(os.path.exists(self.pyc_path))
        self.assert_(self.pyc_ext in self.importer._suffixes)
        py_suffixes = tuple(importlib.suffix_list(imp.PY_SOURCE))
        found = self.importer.find_module(self.module_name)
        self.assert_(found is not None, "module was not found at all")
        # str.endswith() accepts a tuple; an empty tuple never matches, so an
        # empty suffix list still fails the test.
        if not found.endswith(py_suffixes):
            self.fail("Python source files not searched for before bytecode "
                        "files")

    def test_missing__init__warning(self):
        # An ImportWarning should be raised if a directory matches a module
        # name but no __init__ file exists.
        test_support.unlink(self.pkg_init_path)
        with test_support.catch_warning() as w:
            warnings.simplefilter('always', ImportWarning)
            found = self.importer.find_module(self.pkg_name)
            self.assert_(found is None)
            # Make sure a warning was recorded before inspecting it, so a
            # missing warning is reported as a test failure rather than a
            # TypeError from issubclass().  (Assumes catch_warning leaves
            # ``category`` as None when nothing was warned -- confirm.)
            self.assert_(w.category is not None, "no warning was issued")
            self.assert_(issubclass(w.category, ImportWarning))
            self.assert_(str(w.message).endswith("missing __init__"))

    def test_package_before_module(self):
        # A package should always be found before a module with the same name.
        # This should not vary based on whether one is source and another is
        # bytecode, etc.
        module_path = os.path.join(self.directory, self.pkg_name + self.py_ext)
        with open(module_path, 'w') as module_file:
            module_file.write('# Testing packcage over module import.')
        try:
            # Source/source.
            found = self.importer.find_module(self.pkg_name)
            self.assert_('__init__' in found)
            # Source/bytecode.
            py_compile.compile(self.pkg_init_path)
            os.unlink(self.pkg_init_path)
            found = self.importer.find_module(self.pkg_name)
            self.assert_('__init__' in found)
            # Bytecode/bytecode.
            py_compile.compile(module_path)
            os.unlink(module_path)
            found = self.importer.find_module(self.pkg_name)
            self.assert_('__init__' in found)
            # Bytecode/source.
            # tearDown will remove the package __init__ file.
            with open(self.pkg_init_path, 'w') as pkg_file:
                pkg_file.write('# testing package/module import preference.')
            os.unlink(self.pkg_init_path + self._bytecode_suffix)
            found = self.importer.find_module(self.pkg_name)
            self.assert_('__init__' in found)
        finally:
            test_support.unlink(module_path)
            # The module's bytecode was compiled next to module_path.  Remove
            # it under the -O-aware suffix as well as under pyc_ext, in case
            # pyc_ext does not track -O (NOTE(review): confirm how
            # TestPyPycPackages sets pyc_ext).  test_support.unlink ignores
            # files that do not exist.
            test_support.unlink(module_path + self._bytecode_suffix)
            test_support.unlink(os.path.join(self.directory,
                                self.pkg_name + self.pyc_ext))

    def test_module_case_sensitivity(self):
        # Case-sensitivity should always matter as long as PYTHONCASEOK is not
        # set.
        bad_case_name = self._mixed_case(self.top_level_module_name)
        env_guard = test_support.EnvironmentVarGuard()
        env_guard.unset('PYTHONCASEOK')
        with env_guard:
            self.failUnless(not self.importer.find_module(bad_case_name))
        if sys.platform not in self._case_insensitive_platforms:
            return
        env_guard = test_support.EnvironmentVarGuard()
        env_guard.set('PYTHONCASEOK', '1')
        with env_guard:
            self.assert_(os.environ.get('PYTHONCASEOK'))
            self.failUnless(self.importer.find_module(bad_case_name))

    def test_package_case_sensitivity(self):
        # Case-sensitivity should always matter as long as PYTHONCASEOK is not
        # set.
        bad_case_name = self._mixed_case(self.pkg_name)
        bad_init_name = os.path.join(self.directory, self.pkg_name,
                                        '__INit__.py')
        env_guard = test_support.EnvironmentVarGuard()
        env_guard.unset('PYTHONCASEOK')
        with env_guard:
            self.failUnless(not self.importer.find_module(bad_case_name))
            os.unlink(self.pkg_init_path)
            with open(bad_init_name, 'w') as init_file:
                init_file.write('# Test case-sensitivity of imports.')
            self.failUnless(not self.importer.find_module(self.pkg_name))
        # Restore the correctly-cased __init__ file *before* the platform
        # check, so that platforms which return early do not leave the stray
        # '__INit__.py' behind with the real __init__ file missing.
        os.unlink(bad_init_name)
        with open(self.pkg_init_path, 'w') as init_file:
            init_file.write('# Used for testing import.')
        if sys.platform not in self._case_insensitive_platforms:
            return
        env_guard = test_support.EnvironmentVarGuard()
        env_guard.set('PYTHONCASEOK', '1')
        with env_guard:
            self.assert_(os.environ.get('PYTHONCASEOK'))
            self.failUnless(self.importer.find_module(bad_case_name))
            with open(bad_init_name, 'w') as init_file:
                init_file.write('# Used to test case-insensitivity of import.')
            self.failUnless(self.importer.find_module(self.pkg_name))


class ExtensionFileImporterTests(unittest.TestCase):

    """Exercise importlib.ExtensionFileImporter.find_module()."""

    def test_basic(self):
        # An extension module located by find_ext_location() must be found by
        # an importer rooted at the directory that holds it.
        module = 'datetime'
        ext_dir, _ = find_ext_location(module)
        finder = importlib.ExtensionFileImporter(ext_dir)
        # Swap the loader factory so find_module() returns the found path.
        finder._loader = lambda name, path, is_pkg: path
        result = finder.find_module(module)
        self.assert_(result is not None)
        self.assert_(module in result)

    def test_failure(self):
        # A name with no matching extension file must yield None.
        finder = importlib.ExtensionFileImporter('.')
        result = finder.find_module('asdfsdfasdffd')
        self.assert_(result is None)


def test_main():
    """Run this module's importer test cases through test_support."""
    cases = [PyFileImporterTests, ExtensionFileImporterTests]
    test_support.run_unittest(*cases)


if __name__ == '__main__':
    # Executing the file directly runs the whole suite.
    test_main()
