#!/usr/bin/env python2.5
""" Test suite for the fixer modules """
# Original Author: Collin Winter

# Testing imports
try:
    from tests import support
except ImportError:
    import support

# Python imports
import os
import unittest
from itertools import chain
from operator import itemgetter

# Local imports
from .. import pygram, pytree, refactor, fixer_util

class FixerTestCase(support.TestCase):
    """Base class for fixer tests.

    Subclasses set ``fixer`` to the name of the fixer under test and may
    override ``old_version``/``new_version`` to choose the conversion
    direction: code written for ``old_version`` is fed to the refactorer and
    the output is compared against code written for ``new_version``.
    """
    # Version the input ("before") samples are selected for.
    old_version = (3, 0)
    # Version the expected output ("after") samples are selected for.
    new_version = (2, 5)

    def setUp(self, fix_list=None):
        """Build a refactorer for ``fix_list`` (default: ``[self.fixer]``).

        Every fixer's ``log`` is pointed at the shared ``self.fixer_log``
        list so tests can inspect the messages emitted while refactoring.
        """
        if fix_list is None:
            fix_list = [self.fixer]
        options = {"print_function": False}
        self.refactor = support.get_refactorer(fix_list, options,
                                               pkg_name=self.get_pkg_name())
        self.fixer_log = []
        self.filename = "<string>"

        for fixer in chain(self.refactor.pre_order,
                           self.refactor.post_order):
            fixer.log = self.fixer_log

    def _check_strings(self, before, after, ignore_warnings=False):
        """Refactor source string ``before`` and assert it becomes ``after``.

        Both snippets are normalised with ``support.reformat`` first. Unless
        ``ignore_warnings`` is true, also asserts that no fixer logged
        anything. Returns the refactored tree.
        """
        before = support.reformat(before)
        after = support.reformat(after)

        tree = self.refactor.refactor_string(before, self.filename)
        self.assertEqual(after, str(tree))
        if not ignore_warnings:
            self.assertEqual(self.fixer_log, [])
        return tree

    def _check(self, versions, ignore_warnings=False):
        """Verify a fix maps the ``old_version`` sample to the ``new_version`` one.

        ``versions`` is a dict mapping version tuples to sample code. For
        each side, the sample with the greatest version not exceeding the
        target is used (see ``price_is_right``). If either side has no such
        sample, the check is skipped and None is returned. Otherwise the
        refactored tree is returned.

        Example::

            self._check({(3, 0): 'print()',
                         (2, 3): 'print'})
            # The same dict works for 3.x -> 2.x and vice versa.
        """
        before = self.price_is_right(versions, self.old_version)
        after = self.price_is_right(versions, self.new_version)

        # Nothing to compare unless both sides won the Price is Right.
        if before is None or after is None:
            return None

        return self._check_strings(before, after, ignore_warnings)

    def price_is_right(self, versions, target_version):
        """Return the sample for the closest version without going over target.

        Returns None when every key in ``versions`` exceeds
        ``target_version`` (or ``versions`` is empty).
        """
        snippet = None
        for version_key in sorted(versions.keys()):
            if version_key > target_version:
                break
            snippet = versions[version_key]
        return snippet

    def check(self, up, down):
        """Run ``_check`` on the dict matching the conversion direction.

        ``down`` is used when converting to an older version, ``up`` when
        converting to a newer one. Both are checked when the two versions
        are equal.
        """
        if self.old_version > self.new_version:
            self._check(down)
        elif self.old_version < self.new_version:
            self._check(up)
        else:
            self._check(down)
            self._check(up)

    def get_pkg_name(self):
        """Return the fixer package matching ``old_version``."""
        if self.old_version >= (3, 0):
            return 'refactor.fixes.from3'
        else:
            return 'refactor.fixes.from2'

    def warns(self, before, after, message, unchanged=False):
        """Refactor ``before`` into ``after`` and require a logged warning.

        ``before`` and ``after`` are source strings. Asserts that ``message``
        occurs in the concatenated fixer log and, unless ``unchanged`` is
        true, that the tree was modified.
        """
        # Warnings are expected here, so the empty-log check is skipped.
        tree = self._check_strings(before, after, ignore_warnings=True)
        self.assertTrue(message in "".join(self.fixer_log))
        if not unchanged:
            self.assertTrue(tree.was_changed)

    def warns_unchanged(self, before, message):
        """Assert ``before`` is left as-is but ``message`` is logged."""
        self.warns(before, before, message, unchanged=True)

    def unchanged(self, before, ignore_warnings=False):
        """Assert the fixers leave source string ``before`` unmodified.

        Unless ``ignore_warnings`` is true, also asserts nothing was logged.
        """
        self._check_strings(before, before, ignore_warnings=ignore_warnings)

    def assert_runs_after(self, *names):
        """Assert ``self.fixer`` runs after every fixer named in ``names``."""
        fixes = [self.fixer]
        fixes.extend(names)
        options = {"print_function": False}
        # Use the same fixer package as setUp so the named fixers resolve.
        r = support.get_refactorer(fixes, options,
                                   pkg_name=self.get_pkg_name())
        (pre, post) = r.get_fixers()
        n = "fix_" + self.fixer
        if post and post[-1].__class__.__module__.endswith(n):
            # We're the last fixer to run
            return
        if pre and pre[-1].__class__.__module__.endswith(n) and not post:
            # We're the last in pre and post is empty
            return
        self.fail("Fixer run order (%s) is incorrect; %s should be last."
                  % (", ".join([x.__class__.__module__ for x in (pre + post)]),
                     n))

class Test_range(FixerTestCase):
    """Tests for the ``range`` fixer (3.0 ``range`` vs 2.5 ``xrange``)."""

    fixer = "range"

    def test_xrange(self):
        samples = {
            (3, 0): """x = range(0, 10, 2)""",
            (2, 5): """x = xrange(0, 10, 2)""",
        }
        self.check(up={}, down=samples)

    def test_range(self):
        samples = {
            (3, 0): """x = list(range(0, 10, 2))""",
            (2, 5): """x = list(xrange(0, 10, 2))""",
        }
        self.check(up={}, down=samples)

class Test_renames(FixerTestCase):
    """Tests for the ``renames`` fixer (``sys.maxsize`` vs ``sys.maxint``)."""

    fixer = "renames"

    def test_maxint(self):
        samples = {
            (2, 6): """sys.maxsize""",
            (2, 5): """sys.maxint""",
        }
        self.check(up={}, down=samples)

class Test_print(FixerTestCase):
    """Tests for the ``print`` fixer.

    Each case pairs 3.0 ``print()`` call syntax with the equivalent 2.5
    ``print`` statement. Background:
    http://docs.python.org/3.0/whatsnew/3.0.html

    Old: print "The answer is", 2*2
    New: print("The answer is", 2*2)

    Old: print x,           # Trailing comma suppresses newline
    New: print(x, end=" ")  # Appends a space instead of a newline

    Old: print              # Prints a newline
    New: print()            # You must call the function!

    Old: print >>sys.stderr, "fatal error"
    New: print("fatal error", file=sys.stderr)

    Old: print (x, y)       # prints repr((x, y))
    New: print((x, y))      # Not the same as print(x, y)!
    """

    fixer = "print"

    def _check_down(self, py25_code, py30_code):
        # Every case here only supplies "down" samples; "up" stays empty.
        self.check({}, {(2, 5): py25_code, (3, 0): py30_code})

    def test_func(self):
        self._check_down("""print""", """print()""")

    def test_x(self):
        self._check_down("""print x""", """print(x)""")

    def test_str(self):
        self._check_down("""print ''""", """print('')""")

    def test_compound(self):
        self._check_down("""print "The answer is", 2*2""",
                         """print("The answer is", 2*2)""")

    def test_end(self):
        self._check_down("""print x, """, """print(x, end=" ")""")

    def test_stderr(self):
        self._check_down("""print >>sys.stderr, 'fatal error'""",
                         """print('fatal error', file=sys.stderr)""")

if __name__ == "__main__":
    # When run as a script, hand this module to the support runner.
    import __main__ as this_module
    support.run_all_tests(this_module)
