"""Benchmark importing.

There are two aspects to importing that these benchmarks should cover.  One is
the importing of various types of modules (such as extension modules, frozen
modules, etc.).  The other aspect is the general mechanism of importing (such as
walking sys.path, using path_hooks, etc.).

In general, though, one encompasses the other.  By minimizing the entries in
sys.path and sys.meta_path when benchmarking the import of different types of
files, you inadvertently benchmark the fundamental import machinery as well.
This means that only extreme cases of the import machinery need to be explicitly
tested.

Another thing to keep in mind is the expected change in Py3K to the 'level'
argument.  Since Py3K is a definite target for this code, the 'level' value
should be set explicitly rather than being allowed to default to -1.  While
benchmarking the classic (level -1) relative import style would be helpful, it
should be its own benchmark.

"""
from py_compile import compile as compile_to_pyc
import os
from shutil import rmtree
import StringIO
import sys
import tempfile
from timeit import Timer

def import_and_clear_timer(module_name, globals_=None, locals_=None,
                           fromlist=None, level=0, absolute_name=None):
    """Return a Timer whose statement imports a module and then evicts it.

    The import is performed once up front to discover (and verify) the name
    under which the module lands in sys.modules.  The module is then removed
    again, so that every timed iteration performs a real import.

    Arguments mirror __import__().  globals_, locals_ and fromlist default to
    fresh empty containers.  For a relative import, pass *absolute_name* (the
    submodule's full dotted name).  fromlist[0] must then name that
    submodule: its attribute on the parent package is deleted as well.

    Raises RuntimeError if the import did not register *absolute_name* in
    sys.modules.
    """
    if globals_ is None:
        globals_ = {}
    if locals_ is None:
        locals_ = {}
    if fromlist is None:
        fromlist = []
    # Arguments are embedded via repr() so the timed statement is
    # self-contained and can run in timeit's private namespace.
    import_stmt = ('__import__(%r, %r, %r, %r, %s)' %
                   (module_name, globals_, locals_, fromlist, level))
    # Need to get absolute module name in case benchmark is doing a relative
    # import.
    module = eval(import_stmt)
    is_relative = bool(absolute_name)
    if not is_relative:
        absolute_name = module.__name__
    # Explicit check instead of assert: this file is meant to run under -O too.
    if absolute_name not in sys.modules:
        raise RuntimeError('%r was not added to sys.modules by the import'
                           % absolute_name)
    del_stmt = 'del sys.modules[%r]' % absolute_name
    if is_relative:
        # "from . import x" also binds x on the package; drop that too so the
        # next import starts from scratch.
        del_stmt += '; delattr(sys.modules[%r], %r)' % (module.__name__,
                                                         fromlist[0])
    # Evict now so that the first timed iteration is a fresh import as well.
    exec(del_stmt, globals(), {})
    stmt = import_stmt + '; ' + del_stmt
    return Timer(stmt, 'import sys')

def save_import_state(fxn):
    """Decorator: back up the import state before a call and restore it after.

    sys.modules, sys.path, sys.meta_path, sys.path_importer_cache and
    sys.path_hooks are snapshotted before each call.  They are put back
    afterwards, even if *fxn* raises, so a benchmark may freely clobber them.
    The wrapper keeps the wrapped function's name and docstring.
    """
    import functools

    @functools.wraps(fxn)
    def inner(*args, **kwargs):
        modules = sys.modules.copy()
        path = sys.path[:]
        meta_path = sys.meta_path[:]
        path_importer_cache = sys.path_importer_cache.copy()
        path_hooks = sys.path_hooks[:]
        try:
            return fxn(*args, **kwargs)
        finally:
            # sys.modules and the importer cache are mutated in place (other
            # code holds references to those exact dicts); the lists are
            # simply rebound to the saved copies.
            sys.modules.clear()
            sys.modules.update(modules)
            sys.path = path
            sys.meta_path = meta_path
            sys.path_importer_cache.clear()
            sys.path_importer_cache.update(path_importer_cache)
            sys.path_hooks = path_hooks

    return inner

class PyPycFiles(object):

    """Create .py and/or .pyc files for importing.

    The files are written into a private temporary directory, which is
    appended to sys.path on entry.  On exit the directory is taken off
    sys.path and deleted together with everything in it.
    """

    def __init__(self, py=True, pyc=True, module_name="tempmod"):
        """Specify what type of files to create.

        py -- keep the .py source file on disk.
        pyc -- byte-compile the source (.pyc, or .pyo under -O).
        module_name -- name of the module the files implement.
        """
        self.py = py
        self.pyc = pyc
        self.module_name = module_name

    def __enter__(self):
        """Create requested files."""
        # A private directory keeps unrelated files in the shared temp
        # directory from being importable.  It also avoids clashes with
        # leftovers from an earlier (possibly crashed) run.
        self.directory = tempfile.mkdtemp()
        self.py_path = os.path.join(self.directory, self.module_name + '.py')
        self.pyc_path = self.py_path + ('c' if __debug__ else 'o')
        self._saved_dont_write_bytecode = sys.dont_write_bytecode
        try:
            with open(self.py_path, 'w') as py_file:
                py_file.write("# Temporary file for benchmarking import.")
            if self.pyc:
                compile_to_pyc(self.py_path, doraise=True)
            else:
                # Otherwise the first import would write a bytecode file, and
                # every later import would load it -- defeating a "without
                # pyc" benchmark.
                sys.dont_write_bytecode = True
            if not self.py:
                os.remove(self.py_path)
            sys.path.append(self.directory)
            return self
        except Exception:
            self.__exit__()
            raise

    def __exit__(self, *args):
        """Clean up created state."""
        sys.dont_write_bytecode = self._saved_dont_write_bytecode
        if self.directory in sys.path:
            sys.path.remove(self.directory)
        rmtree(self.directory)
            
class PyPycPackage(object):

    """Create .py files for testing the importing of packages.

    The package is created inside a private temporary directory, which is
    appended to sys.path on entry.  On exit the directory is taken off
    sys.path and deleted.
    """

    def __init__(self, pkg_name="testpkg", module_name="testmod"):
        """Specify the names for the package and module within the package."""
        self.pkg_name = pkg_name
        self.module_name = module_name

    def __enter__(self):
        """Create the package and module."""
        # A private directory means a stale package left in the shared temp
        # directory by an earlier run cannot make os.mkdir() fail.
        self.directory = tempfile.mkdtemp()
        self.pkg_path = os.path.join(self.directory, self.pkg_name)
        try:
            os.mkdir(self.pkg_path)
            self.pkg_init_path = os.path.join(self.pkg_path, '__init__.py')
            with open(self.pkg_init_path, 'w') as init_file:
                init_file.write('# This file is used for benchmarking package '
                                'imports')
            self.full_module_name = self.pkg_name + '.' + self.module_name
            self.module_path = os.path.join(self.pkg_path,
                                            self.module_name + '.py')
            with open(self.module_path, 'w') as module_file:
                module_file.write('# This file is used to benchmark importing '
                                    'modules in a package.')
            sys.path.append(self.directory)
            return self
        except Exception:
            self.__exit__()
            raise

    def __exit__(self, *args):
        """Cleanup our mess."""
        if self.directory in sys.path:
            sys.path.remove(self.directory)
        rmtree(self.directory)


@save_import_state
def bench_sys_modules(repetitions, iterations):
    """Benchmark returning a module from sys.modules.

    sys.path is reduced to the temporary module's directory.  sys.meta_path
    and sys.path_importer_cache are emptied.  The module is imported once
    beforehand, so every timed import is served from sys.modules.
    """
    sys.path = []
    sys.meta_path = []
    sys.path_importer_cache.clear()
    with PyPycFiles() as file_state:
        # Force the module into sys.modules.  'level' is passed explicitly
        # (absolute import) to match the timed statement below.
        __import__(file_state.module_name, {}, {}, [], 0)
        timer = Timer("__import__(%r, {}, {}, [], 0)" % file_state.module_name)
        return timer.repeat(repetitions, iterations)

@save_import_state
def bench_nonexistent_module(repetitions, iterations):
    """How expensive is an import failure (with a single entry on sys.path)?"""
    sys.path = []
    sys.meta_path = []
    bad_name = 'sadfflkjsdf'
    # Have at least one entry on sys.path.
    with PyPycFiles():
        # Make sure the fake module name is not actually real.  Use the same
        # explicit absolute (level=0) import that is timed below, rather than
        # letting 'level' default to -1.
        try:
            __import__(bad_name, {}, {}, [], 0)
        except ImportError:
            pass
        else:
            raise Exception("supposed non-existent module actually exists")
        timer = Timer("try:__import__(%r, {}, {}, [], 0)\nexcept ImportError:pass" %
                        bad_name)
        return timer.repeat(repetitions, iterations)

def py_pyc_module_benchmark(py, pyc):
    """Build a benchmark that imports a module backed by .py and/or .pyc files.

    *py* and *pyc* are handed to PyPycFiles to choose which files exist.  The
    returned function takes (repetitions, iterations) and runs with the import
    state saved and restored around it.
    """
    def inner(repetitions, iterations):
        sys.meta_path = []
        sys.path = []
        with PyPycFiles(py, pyc) as files:
            module_timer = import_and_clear_timer(files.module_name)
            return module_timer.repeat(repetitions, iterations)
    return save_import_state(inner)
     
# One benchmark per on-disk file combination: .py plus .pyc, only the .pyc,
# and only the .py source.
bench_pyc_with_py = py_pyc_module_benchmark(py=True, pyc=True)
bench_pyc_without_py = py_pyc_module_benchmark(py=False, pyc=True)
bench_py_without_pyc = py_pyc_module_benchmark(py=True, pyc=False)

@save_import_state
def bench_builtins(repetitions, iterations):
    """Time the import of the built-in module xxsubtype.

    sys.path and sys.meta_path are emptied first, so no path entry or
    meta-path finder takes part in the import.
    """
    sys.path = []
    sys.meta_path = []
    return import_and_clear_timer('xxsubtype').repeat(repetitions, iterations)
 
@save_import_state
def bench_frozen(repetitions, iterations):
    """Benchmark the importing of frozen modules (using __hello__)."""
    sys.path = []
    sys.meta_path = []
    # Must suppress output from importing __hello__.  Restore whatever stdout
    # was active before (not sys.__stdout__), so a caller's redirection of
    # sys.stdout survives the benchmark.
    saved_stdout = sys.stdout
    sys.stdout = StringIO.StringIO()
    try:
        timer = import_and_clear_timer('__hello__')
        return timer.repeat(repetitions, iterations)
    finally:
        sys.stdout = saved_stdout

@save_import_state
def bench_extension(repetitions, iterations):
    """Time importing an extension module from its own directory.

    datetime is used because it is expected to ship with Python.  It is
    imported once to learn which directory holds it; that directory then
    becomes the only sys.path entry.
    """
    module = 'datetime'
    location = os.path.dirname(__import__(module, {}, {}, [], 0).__file__)
    sys.meta_path = []
    sys.path = [location]
    return import_and_clear_timer(module).repeat(repetitions, iterations)
    
@save_import_state
def bench_long_sys_path(repetitions, iterations, extra_path_entries=20):
    """See the impact of having a large number of entries on sys.path that do
    not have the module.

    *extra_path_entries* bogus, cwd-relative entries ('0dummy_entry', ...)
    are placed in front of the directory holding the module.  Each import
    therefore walks past all of them before finding it.
    """
    sys.path = []
    sys.meta_path = []
    with PyPycFiles() as file_state:
        # Prepend the dummy entries ahead of the temporary module's directory.
        # (Presumably none of these names exist relative to the cwd.)
        for entry in xrange(extra_path_entries):
            sys.path.insert(0, str(entry) + 'dummy_entry')
        timer = import_and_clear_timer(file_state.module_name)
        return timer.repeat(repetitions, iterations)
        
@save_import_state
def bench_package(repetitions, iterations):
    """Time importing a freshly created package (i.e. its __init__ module)."""
    sys.path = []
    sys.meta_path = []
    with PyPycPackage() as package:
        package_timer = import_and_clear_timer(package.pkg_name)
        return package_timer.repeat(repetitions, iterations)
        
@save_import_state
def bench_module_in_package(repetitions, iterations):
    """Benchmark importing a module that is contained within a package.

    The parent package is imported once beforehand and stays in sys.modules.
    Only the submodule is evicted and re-imported on each iteration.
    """
    sys.path = []
    sys.meta_path = []
    with PyPycPackage() as pkg_details:
        # Import the package first to prime sys.modules.
        __import__(pkg_details.pkg_name, {}, {}, [], 0)
        # With an empty fromlist, __import__('pkg.mod') returns the top-level
        # *package*.  import_and_clear_timer would then evict the package
        # while the cached submodule was never re-imported.  A non-empty
        # fromlist makes __import__ return the submodule itself, so it is the
        # one evicted and re-imported each time.
        # Shouldn't have to worry about the attribute pre-existing on the
        # package for the module; it should get reset by the import.
        timer = import_and_clear_timer(pkg_details.full_module_name,
                                       fromlist=[pkg_details.module_name])
        return timer.repeat(repetitions, iterations)
        
@save_import_state
def bench_relative_import_in_package(repetitions, iterations):
    """Time ``from . import <module>`` as seen from inside the package."""
    sys.path = []
    sys.meta_path = []
    with PyPycPackage() as package:
        parent = __import__(package.pkg_name, {}, {}, [], 0)
        # Fake the globals of code living in the package's __init__, so that
        # a level-1 import resolves relative to the package itself.
        fake_globals = {'__name__': parent.__name__,
                        '__path__': parent.__path__}
        relative_timer = import_and_clear_timer(
            '', fake_globals, {}, [package.module_name], 1,
            package.full_module_name)
        return relative_timer.repeat(repetitions, iterations)


def display_results(fxn, name, spacing, repetitions, iterations):
    """Run benchmark *fxn* and print one row of per-repetition timings.

    Each timing (in seconds) is truncated to whole milliseconds.  The list is
    printed after *name*, which is left-justified to *spacing* columns.
    """
    timings = fxn(repetitions, iterations)
    millis = [int(seconds * 1000) for seconds in timings]
    label = name.ljust(spacing)
    print("%s %s" % (label, millis))

def main(tests=None, repetitions=3, iterations=10000):
    """Run benchmarks and print a table of their timings.

    tests -- names of module-level 'bench_' functions to run, in the given
             order.  Defaults to every such function, sorted by name so the
             output order is reproducible.
    repetitions, iterations -- forwarded to every benchmark.

    Raises ValueError, before anything runs, if a name is not a 'bench_'
    function of this module.
    """
    globals_ = globals()
    if tests is None:
        tests = sorted(name for name in globals_ if name.startswith('bench_'))
    else:
        tests = list(tests)
    unknown = [name for name in tests
               if not name.startswith('bench_') or name not in globals_]
    if unknown:
        raise ValueError('unknown benchmark(s): %s' % ', '.join(unknown))
    # Size the name column to the benchmarks actually shown (and the header),
    # not to every name in the module.
    name_width = max([len('benchmark')] + [len(name) for name in tests])
    header = "%s time per repetition (in ms)" % "benchmark".ljust(name_width)
    print("Repeating tests %s times with %s iterations ..." %
          (repetitions, iterations))
    print("")
    print(header)
    print('-' * len(header))
    for name in tests:
        display_results(globals_[name], name, name_width, repetitions,
                        iterations)

            
if __name__ == '__main__':
    # XXX support specifying number of repetitions and iterations from the
    # command-line.
    # Benchmark names may be given as arguments; none given means run them all.
    requested = sys.argv[1:]
    main(requested or None)
