import controlled_importlib
import importlib

from tests import mock_importlib
from tests.ext_help import find_ext_location
from tests.py_help import TestPyPycPackages

from contextlib import contextmanager, nested
import os
import StringIO
import sys
import unittest
from test import test_support


class DummyImporterLoader(object):

    """Importer/loader stub that says "yes" to everything.

    Serves as the permissive base that DummyWhitelist layers a whitelist
    on top of: both hooks report success by returning True, whatever
    arguments they receive.
    """

    def find_module(self, name, path=None):
        """Claim *name* unconditionally; always returns True."""
        found = True
        return found

    def load_module(self, name):
        """Pretend to load *name*; always returns True."""
        loaded = True
        return loaded

class DummyWhitelist(controlled_importlib.Whitelister, DummyImporterLoader):

    """Whitelister mixed over an always-accepting importer/loader stub.

    Whitelister precedes DummyImporterLoader in the MRO, so its checks see
    each request first; names that pass are expected to fall through to the
    stub's True results (presumably via super() delegation inside
    Whitelister -- confirm in controlled_importlib).
    """


class WhitelistTests(unittest.TestCase):

    """Exercise the generic whitelisting mechanism through DummyWhitelist."""

    def test_whitelist_module(self):
        # An exactly-matching name passes; any other name is refused with
        # None from find_module and ImportError from load_module.
        allowed = 'ok_mod'
        checker = DummyWhitelist([allowed])
        self.failUnless(checker.find_module(allowed) is True)
        self.failUnless(checker.find_module('fdssdf') is None)
        self.failUnless(checker.load_module(allowed) is True)
        self.failUnlessRaises(ImportError, checker.load_module, 'asdfadsf')

    def test_whitelist_list(self):
        # Every entry of a multi-entry whitelist must be honoured, whatever
        # its position in the list.
        names = ('A', 'B')
        checker = DummyWhitelist(list(names))
        for allowed in names:
            self.failUnless(checker.find_module(allowed) is True)
            self.failUnless(checker.load_module(allowed) is True)

    def test_block_partial_name(self):
        # Near misses are still blocked: first a name that merely starts
        # with the whitelisted one, then a name that is merely a prefix of it.
        allowed = 'mod'
        checker = DummyWhitelist([allowed])
        for near_miss in (allowed + '2', allowed[:-1]):
            self.failUnless(checker.find_module(near_miss) is None)
            self.failUnlessRaises(ImportError, checker.load_module,
                                    near_miss)

    def test_package(self):
        # Whitelisting a package does not implicitly whitelist its
        # submodules.
        package = 'pkg'
        checker = DummyWhitelist([package])
        submodule = '.'.join((package, 'mod'))
        self.failUnless(checker.find_module(submodule) is None)
        self.failUnlessRaises(ImportError, checker.load_module, submodule)


class WhitelistBuiltinTests(unittest.TestCase):

    """Check that WhitelistBuiltin only exposes whitelisted built-ins."""

    def setUp(self):
        # Use the first two built-in module names: one allowed, one not.
        builtin_names = sys.builtin_module_names
        self.whitelist = builtin_names[0]
        self.blacklist = builtin_names[1]

    def test_whitelist(self):
        # find_module: the allowed name yields a non-None result, the other
        # yields None.  load_module: the allowed name loads (truthy result),
        # the other raises ImportError.
        checker = controlled_importlib.WhitelistBuiltin([self.whitelist])
        self.failUnless(checker.find_module(self.whitelist) is not None)
        self.failUnless(checker.find_module(self.blacklist) is None)
        self.failUnless(checker.load_module(self.whitelist))
        self.failUnlessRaises(ImportError, checker.load_module,
                                self.blacklist)


class WhitelistFrozenTests(unittest.TestCase):

    """Test whitelisting of frozen modules."""

    def setUp(self):
        # Capture anything written to stdout while the frozen modules load.
        # Remember the stream that was active so tearDown can put it back;
        # it is not necessarily sys.__stdout__ (e.g. a test runner may have
        # redirected it already).
        self._saved_stdout = sys.stdout
        sys.stdout = StringIO.StringIO()
        self.whitelist = '__phello__'
        # A different frozen module and a submodule of the whitelisted
        # package: both must be refused.
        self.blacklist = ('__hello__', '__phello__.spam')

    def tearDown(self):
        sys.stdout = self._saved_stdout

    def test_whitelist(self):
        # The whitelisted frozen package can be found and loaded; every
        # blacklisted name yields None from find_module and ImportError from
        # load_module.
        imp_load = controlled_importlib.WhitelistFrozen([self.whitelist])
        self.failUnless(imp_load.find_module(self.whitelist) is not None)
        self.failUnless(imp_load.load_module(self.whitelist))
        for blacklist in self.blacklist:
            self.failUnless(imp_load.find_module(blacklist) is None)
            self.failUnlessRaises(ImportError, imp_load.load_module, blacklist)


class WhitelistExtensionsTests(unittest.TestCase):

    """Check that WhitelistExtensionImporter only finds allowed extensions."""

    def test_whitelist(self):
        allowed, denied = 'time', 'datetime'
        ext_dir = find_ext_location(allowed)[0]
        # Both extensions must live in the same directory; otherwise the
        # "denied" check below could pass merely because the importer is
        # looking in the wrong place.
        assert ext_dir == find_ext_location(denied)[0]
        importer = controlled_importlib.WhitelistExtensionImporter(
                        [allowed], ext_dir)
        self.failUnless(importer.find_module(allowed) is not None)
        self.failUnless(importer.find_module(denied) is None)




@contextmanager
def mutate_sys_modules(module, name):
    """Temporarily bind *name* to *module* in sys.modules.

    On exit the previous entry for *name* is restored exactly, including an
    entry whose value is None or otherwise false. If *name* was not present
    beforehand, it is removed again.
    """
    # Record presence separately from the value: sys.modules may legitimately
    # map a name to None, and that must be restored, not deleted.
    had_entry = name in sys.modules
    old_module = sys.modules.get(name)
    sys.modules[name] = module
    try:
        yield
    finally:
        if had_entry:
            sys.modules[name] = old_module
        else:
            # pop() instead of del: the body may already have removed it.
            sys.modules.pop(name, None)

@contextmanager
def remove_from_sys_modules(*modules):
    """Temporarily remove the named modules from sys.modules.

    Each name that was present on entry is restored to its original module
    object on exit. Names that were absent on entry are left untouched,
    including when the body imports them.
    """
    # Map name -> module so each module is restored under its own name,
    # even when some of the requested names were not cached.
    cached = {}
    try:
        for name in modules:
            if name in sys.modules:
                cached[name] = sys.modules.pop(name)
        yield
    finally:
        sys.modules.update(cached)

@contextmanager
def temp_setattr(obj, attribute, value):
    """Temporarily set *attribute* on *obj* to *value*.

    Yields *value*. On exit the previous value is restored. If the attribute
    did not exist beforehand, it is removed again.
    """
    missing = object()
    old_value = getattr(obj, attribute, missing)
    setattr(obj, attribute, value)
    try:
        yield value
    finally:
        if old_value is missing:
            # The attribute did not exist before; drop the temporary one
            # (tolerating the body having removed it already).
            try:
                delattr(obj, attribute)
            except AttributeError:
                pass
        else:
            setattr(obj, attribute, old_value)


class ControlledImportMethodTests(unittest.TestCase):

    """Test explicit methods of ControlledImport."""

    def setUp(self):
        # Snapshot sys.path_hooks before ControlledImport is built (it
        # presumably installs hooks of its own) so tearDown can restore it.
        self._saved_path_hooks = list(sys.path_hooks)
        self.import_ = controlled_importlib.ControlledImport([], [], [])

    def tearDown(self):
        # Restore the snapshot rather than assigning an empty list, which
        # would strip the interpreter's default path hooks from every test
        # that runs afterwards.
        sys.path_hooks = self._saved_path_hooks

    def test_module_from_cache(self):
        # A cached sys.modules entry whose name has a leading dot must not be
        # returned by module_from_cache; ImportError is expected instead.
        module = 'module'
        module_name = '.blocked'
        assert module_name.startswith('.')
        with mutate_sys_modules(module, module_name):
            self.failUnlessRaises(ImportError, self.import_.module_from_cache,
                                    module_name)

    def test_post_import(self):
        # post_import hands back the same module object and strips any
        # __loader__ attribute from it.
        module = mock_importlib.MockModule()
        self.failUnless(self.import_.post_import(module) is module)
        module.__loader__ = None
        stripped_module = self.import_.post_import(module)
        self.failUnless(stripped_module is module)
        self.failUnless(not hasattr(stripped_module, '__loader__'))


class ControlledImportUsageTests(TestPyPycPackages):

    """Make sure that usage of ControlledImport works properly."""

    def setUp(self):
        """Create .py and .pyc files for testing purposes."""
        # Snapshot sys.path_hooks first so tearDown can restore it exactly,
        # whatever the base-class setUp or ControlledImport does to it.
        self._saved_path_hooks = list(sys.path_hooks)
        super(ControlledImportUsageTests, self).setUp(faked_names=False)

    def tearDown(self):
        super(ControlledImportUsageTests, self).tearDown()
        # Restore the snapshot instead of assigning an empty list, which
        # would strip the default path hooks from every later test.
        sys.path_hooks = self._saved_path_hooks

    def test_block_dot_modules(self):
        # Modules with a leading dot must not be imported, even when an
        # entry for them is already in sys.modules.
        import_ = controlled_importlib.ControlledImport()
        module = mock_importlib.MockModule()
        module_name = '.block'
        with mutate_sys_modules(module, module_name):
            self.failUnlessRaises(ImportError, import_, module_name, level=0)

    def test_builtin_whitelisting(self):
        # Only the whitelisted built-in module may be imported.
        whitelist = sys.builtin_module_names[0]
        blacklist = sys.builtin_module_names[1]
        with remove_from_sys_modules(whitelist, blacklist):
            import_ = controlled_importlib.ControlledImport([whitelist])
            module = import_(whitelist, level=0)
            self.failUnlessEqual(module.__name__, whitelist)
            self.failUnlessRaises(ImportError, import_, blacklist, level=0)

    def test_frozen_whitelisting(self):
        # Only the whitelisted frozen package may be imported. Another frozen
        # module and a submodule of the whitelisted package are refused.
        # stdout is redirected into a StringIO while the modules load.
        whitelist = '__phello__'
        blacklist = ('__hello__', '__phello__.spam')
        with nested(temp_setattr(sys, 'stdout', StringIO.StringIO()),
                remove_from_sys_modules(whitelist, *blacklist)):
            import_ = controlled_importlib.ControlledImport([], [whitelist], [])
            module = import_(whitelist, level=0)
            self.failUnlessEqual(module.__name__, whitelist)
            for blacklisted in blacklist:
                self.failUnlessRaises(ImportError, import_, blacklisted,
                                        level=0)

    def test_extension_whitelisting(self):
        # Only the whitelisted extension module may be imported.
        whitelist = 'time'
        blacklist = 'datetime'
        with remove_from_sys_modules(whitelist, blacklist):
            import_ = controlled_importlib.ControlledImport([], [], [whitelist])
            module = import_(whitelist, level=0)
            self.failUnlessEqual(module.__name__, whitelist)
            self.failUnlessRaises(ImportError, import_, blacklist, level=0)

    def test_pyc_blocking(self):
        # Importing a module that exists only as a .pyc file must fail.
        with remove_from_sys_modules(self.module_name):
            import_ = controlled_importlib.ControlledImport()
            os.unlink(self.py_path)
            assert os.path.exists(self.pyc_path)
            assert not os.path.exists(self.py_path)
            self.failUnlessRaises(ImportError, import_, self.module_name,
                    level=0)

    def test_py(self):
        # Should be able to import a .py module.
        with remove_from_sys_modules(self.module_name):
            os.unlink(self.pyc_path)
            import_ = controlled_importlib.ControlledImport()
            self.failUnless(import_(self.module_name, level=0))

    def test_no_pyc_creation(self):
        # No .pyc file should be created by importing a .py file.
        with remove_from_sys_modules(self.module_name):
            os.unlink(self.pyc_path)
            assert os.path.exists(self.py_path)
            assert not os.path.exists(self.pyc_path)
            import_ = controlled_importlib.ControlledImport()
            module = import_(self.module_name, level=0)
            self.failUnless(module)
            self.failUnless(not os.path.exists(self.pyc_path))

    def test_no_loader_attribute(self):
        # No __loader__ attribute should be exposed on any module or package.
        # The sub-package is deliberately left out of the explicit imports: it
        # is only imported implicitly as a parent of sub_pkg_module_name, which
        # checks that implicit dependency imports are stripped as well.
        module_names = [self.top_level_module_name, self.pkg_name,
                self.pkg_module_name, self.sub_pkg_module_name]
        assert self.sub_pkg_name not in module_names
        with remove_from_sys_modules(*module_names):
            import_ = controlled_importlib.ControlledImport()
            for module_name in module_names:
                import_(module_name, level=0)
                self.failUnless(not hasattr(sys.modules[module_name],
                                            '__loader__'))
            self.failUnless(not hasattr(sys.modules[self.sub_pkg_name],
                                        '__loader__'))

    def test_relative_import(self):
        # A relative import within a package should not be able to circumvent
        # whitelisting.
        whitelist = '__phello__'
        blacklist = '__phello__.spam'
        blacklist_module = blacklist.split('.')[1]
        with nested(temp_setattr(sys, 'stdout', StringIO.StringIO()),
                remove_from_sys_modules(whitelist, blacklist)):
            import_ = controlled_importlib.ControlledImport((), [whitelist])
            pkg = import_(whitelist, level=0)
            pkg_fromlist = import_('', pkg.__dict__, {}, [blacklist_module], 1)
            assert pkg_fromlist.__name__ == whitelist
            self.failUnless(not hasattr(pkg_fromlist, blacklist_module))


def test_main():
    """Run every test case class defined in this module."""
    test_cases = (
        WhitelistTests,
        WhitelistBuiltinTests,
        WhitelistFrozenTests,
        WhitelistExtensionsTests,
        ControlledImportMethodTests,
        ControlledImportUsageTests,
    )
    test_support.run_unittest(*test_cases)


if __name__ == '__main__':
    # Allow running this file directly as a script.
    test_main()
