import sys
import unittest

class BuiltinFrozen_Tester(unittest.TestCase):

    """Common test cases for built-in and frozen modules.

    This class is a shared base: concrete subclasses must provide the
    attributes listed below (as class attributes or from their own
    ``setUp``).  When a test loader collects the base class itself, its
    tests are skipped rather than erroring out on the missing attributes.

    Expected attributes:
    * self.importer
        The meta_path importer to test.
    * self.module_name
        Name of a module that the importer can import.
    * self.bad_module_names
        Sequence of module names that cannot be imported.

    """

    def setUp(self):
        # unittest collects every TestCase subclass it sees, including this
        # base class; it has no importer/module_name, so skip instead of
        # producing a wall of AttributeErrors.
        if type(self) is BuiltinFrozen_Tester:
            self.skipTest("abstract base class; subclass it to run these tests")
        super().setUp()

    def test_find_module_basic(self):
        # Make sure expected modules can be found.
        self.assertTrue(self.importer.find_module(self.module_name))

    def test_find_module_failures(self):
        # Modules of the wrong type should not be found.
        for module_name in self.bad_module_names:
            self.assertIsNone(self.importer.find_module(module_name))

    def test_find_module_arg_count(self):
        # Cover proper number of arguments: none, and one too many.
        self.assertRaises(TypeError, self.importer.find_module)
        self.assertRaises(TypeError, self.importer.find_module,
                          self.module_name, None, None)

    def test_load_module_prev_import_check(self):
        # If a module is already in sys.modules it should be returned without
        # being re-initialized, i.e. the very same object comes back.
        name = self.module_name
        was_loaded = name in sys.modules
        try:
            if not was_loaded:
                # Guarantee the "already imported" precondition; otherwise
                # the check below would pass trivially after the first load.
                self.importer.load_module(name)
            existing = sys.modules[name]
            self.assertIs(self.importer.load_module(name), existing)
        finally:
            if not was_loaded:
                # Leave sys.modules as we found it.
                sys.modules.pop(name, None)

    def test_load_module_new_import(self):
        # Make sure importing of a module that was not done before works
        # properly.
        # Do not forget to put any removed module back into sys.modules!
        # Certain modules like 'sys' have some attributes that are set
        # only once during interpreter initialization and are never set
        # again afterwards.
        name = self.module_name
        had_module = name in sys.modules
        saved = sys.modules.pop(name, None)
        try:
            loaded_module = self.importer.load_module(name)
            self.assertEqual(name, loaded_module.__name__)
        finally:
            if had_module:
                sys.modules[name] = saved
            else:
                # Drop the module this test created so no state leaks into
                # later tests.
                sys.modules.pop(name, None)

    def test_load_module_ImportError(self):
        # Every module the importer cannot handle must raise ImportError.
        for module_name in self.bad_module_names:
            self.assertRaises(ImportError, self.importer.load_module,
                              module_name)

    def test_load_module_arg_count(self):
        # Cover proper number of arguments: none, and one too many.
        self.assertRaises(TypeError, self.importer.load_module)
        self.assertRaises(TypeError, self.importer.load_module,
                          self.module_name, None, None)


