#!/usr/bin/env python2.5

"""Sanity checks and dispatch micro-benchmarks for overloading.py."""

import timeit

from overloading import overloaded

# Python 2: a module-level __metaclass__ makes every class statement below
# that lists no bases new-style, as if it derived from object.
__metaclass__ = type

class A(object):
    """One root of the dispatch test hierarchy."""

class B(object):
    """A second root, unrelated to A."""

class C(A, B):
    """Subclass of both A and B.

    A C instance therefore satisfies signatures written for either parent.
    """
def default(x, y):
    """Fallback implementation: ignore both arguments and return "default"."""
    return "default"
# The dispatcher under test.  Its own body only delegates to default(); it is
# presumably what `overloaded` falls back to when no registered signature
# matches (NOTE(review): confirm against overloading.py).
@overloaded
def automatic(x, y): return default(x, y)
# run() prints func.__name__, so give the dispatcher a readable name for the
# benchmark report (presumably the `overloaded` wrapper does not carry the
# wrapped function's name over — TODO confirm).
automatic.__name__ = "automatic"
# One implementation per (type of x, type of y) signature.  C derives from
# both A and B, so a C argument can satisfy several of these signatures.
@automatic.register(A, B)
def methodAB(x, y): return "AB"
# NOTE(review): methodAC's parameters are named A and C, shadowing the
# classes inside the function body.  Harmless here because the body ignores
# them, but they differ from the x, y used by the sibling methods; renaming
# them would break any caller passing them by keyword.
@automatic.register(A, C)
def methodAC(A, C): return "AC"
@automatic.register(B, A)
def methodBA(x, y): return "BA"
@automatic.register(C, B)
def methodCB(x, y): return "CB"

# Sanity-check the overloaded dispatcher before timing it.
assert automatic(C(), B()) == "CB"
assert automatic(A(), C()) == "AC"
# (C, C) satisfies several registered signatures, e.g. (A, C) and (C, B),
# neither of which is more specific than the other, so the call is expected
# to be rejected with TypeError (presumably as ambiguous).
try:
    automatic(C(), C())
except TypeError:
    pass
else:
    assert False, "automatic(C(), C()) should raise TypeError"

def accelerated(x, y):
    """Dispatch through automatic's cache, bypassing the wrapper on a hit.

    Looks up the exact (type(x), type(y)) pair in `automatic.cache` and
    calls whatever callable is stored there with (x, y); on a miss it calls
    `automatic` itself (which presumably resolves the call and fills the
    cache — confirm in overloading.py).  Because the key uses exact types,
    an entry stored under a base class is not reused for a subclass
    argument.

    Kept as a single expression on purpose: this is the variant timed as
    "accelerated" in the benchmark runs.
    """
    return automatic.cache.get((type(x), type(y)), automatic)(x, y)

# The cache-first wrapper must agree with automatic, including rejecting the
# (C, C) pair with TypeError.
assert accelerated(C(), B()) == "CB"
assert accelerated(A(), C()) == "AC"
try:
    accelerated(C(), C())
except TypeError:
    pass
else:
    assert False, "accelerated(C(), C()) should raise TypeError"

def manual(x, y):
    """Hand-written isinstance dispatch; the baseline the other variants race.

    Mirrors the registrations on `automatic`: x is tested against C (a
    subclass of both A and B) first, then B, then A; when x is an A, y is
    tested against C before B.  A pair matching no branch falls through to
    default().

    Not an exact equivalent of `automatic`: for (C(), C()) this returns
    "CB" from the first branch, whereas automatic(C(), C()) is expected to
    raise TypeError.
    """
    if isinstance(x, C):
        if isinstance(y, B):
            return methodCB(x, y)
    # A C is also a B, so a C x whose y is not a B falls through to here.
    if isinstance(x, B):
        if isinstance(y, A):
            return methodBA(x, y)
    if isinstance(x, A):
        # Check y against C before B: a C is also a B.
        if isinstance(y, C):
            return methodAC(x, y)
        if isinstance(y, B):
            return methodAB(x, y)
    return default(x, y)

# Check that the manual version gives the expected answer for every argument
# pair timed in the benchmark runs.  Also check the ambiguous (C, C) pair,
# which manual resolves to its first matching branch instead of rejecting it.
assert manual(C(), B()) == "CB"
assert manual(A(), A()) == "default"
assert manual(A(), B()) == "AB"
assert manual(A(), C()) == "AC"
assert manual(C(), C()) == "CB"

# Test fixture
def run(func, C1, C2, repeat=3, number=10000):
    """Time func(C1(), C2()) with timeit and print one report line.

    The call is executed `number` times per run, for `repeat` runs in total.
    The reported figure is the fastest run's total time in milliseconds,
    rounded to an integer.  With the defaults, that is msec per 10000 calls.

    timeit.Timer compiles its statement string in the timeit module's own
    global namespace.  The callable and its two arguments are therefore
    planted there as module attributes while timing.  Afterwards they are
    removed, or any previous values are restored, so they do not leak into
    later users of the timeit module.
    """
    # Build the arguments before touching timeit so a failing constructor
    # leaves the module untouched.
    planted = {"test_func": func, "test_arg1": C1(), "test_arg2": C2()}
    missing = object()
    saved = dict((name, getattr(timeit, name, missing)) for name in planted)
    for name, value in planted.items():
        setattr(timeit, name, value)
    try:
        timer = timeit.Timer("test_func(test_arg1, test_arg2)")
        best = min(timer.repeat(repeat, number))
    finally:
        for name, old in saved.items():
            if old is missing:
                delattr(timeit, name)
            else:
                setattr(timeit, name, old)
    result = int(round(best * 1000))
    # Single-argument print(...) parses identically as a Python 2 statement
    # and a Python 3 call.
    print("%s(%s(), %s()) - %3d msec" % (func.__name__,
                                         C1.__name__,
                                         C2.__name__,
                                         result))

# Test runs: time every dispatch strategy on the same argument-type pairs so
# the numbers are directly comparable; a separator line precedes each one.
# (C, B), (A, B) and (A, C) match a registered signature exactly; (A, A)
# matches none of them, so it exercises the fall-back to default().
BENCH_PAIRS = ((C, B), (A, A), (A, B), (A, C))

for bench_func in (manual, automatic, accelerated):
    print('-' * 20)
    for left, right in BENCH_PAIRS:
        run(bench_func, left, right)
