#!/usr/bin/python
from pairing_heap import *
import random
import unittest

# Number of random values generated for (and pushed through the heap by)
# each test case.
datasize = 2 ** 8

class Holder:
    """Minimal wrapper object used to exercise pairing_heap's ``key=`` option.

    The wrapped value is stored in the ``xyzzy`` attribute; the tests build
    heaps with ``key='xyzzy'`` so ordering is by that attribute.
    """

    def __init__(self, item):
        # Deliberately unusual attribute name so the key lookup is explicit.
        wrapped = item
        self.xyzzy = wrapped

def check_invariant(heap, node=None):
    """Return True if the subtree rooted at *node* is heap-ordered.

    Walks every descendant of *node* (the heap's root, ``heap._root``, when
    *node* is omitted) and checks that no child compares strictly before its
    parent under ``heap._cmp``. Children are reached through ``_child`` and
    then followed along the ``_sibling`` chain. An empty heap trivially
    satisfies the invariant.

    Returns False as soon as any violation is found anywhere in the subtree.
    """
    if node is None:
        node = heap._root
    if node is None:
        return True
    child = node._child
    while child is not None:
        # A child ordered strictly before its parent breaks heap order;
        # equal keys (duplicates) are legal.
        if heap._cmp(node._item, child._item) > 0:
            return False
        # Recurse so violations deeper in the tree are reported, not dropped.
        if not check_invariant(heap, child):
            return False
        child = child._sibling
    return True

# Uncomment the stub below to skip the recursive invariant walk, which
# speeds things up when doing code-coverage analysis.
##def check_invariant(heap, node=None):
##    return True

class TestPairingHeap(unittest.TestCase):
    """Behavioral tests for pairing_heap, each run on fresh random data."""

    @staticmethod
    def _reverse_cmp(x, y):
        # cmp-style comparison that inverts the natural ordering.
        return -cmp(x, y)

    def setUp(self):
        # A new batch of random floats per test, plus the sorted copy that is
        # the expected extraction order for a default (min-first) heap.
        self.data = [random.random() for _ in xrange(datasize)]
        self.data_sorted = sorted(self.data)

    def test_random(self):
        # Push random numbers and pop them off, verifying all's OK.
        heap = pairing_heap()
        self.assertTrue(check_invariant(heap))
        for item in self.data:
            heap.insert(item)
            self.assertTrue(check_invariant(heap))
        results = []
        while not heap.empty():
            item = heap.extract()
            self.assertTrue(check_invariant(heap))
            results.append(item)
        self.assertEqual(self.data_sorted, results)

    def test_naive_nbest(self):
        # Keep only the 10 largest values by evicting the minimum whenever
        # the heap grows past 10 entries.
        heap = pairing_heap()
        for item in self.data:
            heap.insert(item)
            if len(heap) > 10:
                heap.extract()
        result = heap.extract_all()
        self.assertEqual(result, self.data_sorted[-10:])

    def test_underflow(self):
        heap = pairing_heap()
        self.assertRaises(Underflow, heap.extract)
        self.assertRaises(Underflow, heap.peek)

    def test_return_unsorted(self):
        heap = pairing_heap()
        for item in self.data:
            heap.insert(item)
        self.assertEqual(set(heap.values()), set(self.data))

    def test_implicit_creation_and_peek(self):
        heap = pairing_heap(self.data)
        self.assertTrue(check_invariant(heap))
        self.assertEqual(heap.peek(), self.data_sorted[0])
        self.assertTrue(check_invariant(heap))
        result = heap.extract_all()
        self.assertEqual(result, self.data_sorted)

    def test_extract_with_arg(self):
        heap = pairing_heap(self.data)
        low10 = heap.extract(10)
        self.assertTrue(check_invariant(heap))
        self.assertEqual(low10, self.data_sorted[:10])

    def test_deleting_from_the_middle(self):
        heap = pairing_heap()
        nodes = [heap.insert(item) for item in self.data]
        remaining = self.data[:]
        # Visit indices from the end so deleting list entries never shifts
        # an index that is still to be visited.
        for count, i in enumerate(xrange(datasize - 1, -1, -5), 1):
            heap.delete(nodes[i])
            del nodes[i]
            del remaining[i]
            self.assertTrue(check_invariant(heap))
            self.assertEqual(datasize - count, len(heap))
        self.assertEqual(sorted(remaining), heap.extract_all())

    def test_deleting_everything_in_random_order(self):
        heap = pairing_heap()
        nodes = [heap.insert(item) for item in self.data]
        for _ in xrange(datasize):
            # Pick from our own node list so the choice doesn't depend on
            # len(heap) being right; check that length separately.
            t = random.randrange(len(nodes))
            heap.delete(nodes.pop(t))
            self.assertTrue(check_invariant(heap))
            self.assertEqual(len(nodes), len(heap))
        self.assertTrue(heap.empty())

    def test_adjust_key(self):
        heap = pairing_heap()
        nodes = [heap.insert(item) for item in self.data]
        # Each new key is <= its old key, so every adjustment is a decrease.
        new_keys = [min(old, random.random()) for old in self.data]
        order = list(range(datasize))
        random.shuffle(order)
        for i in order:
            heap.adjust_key(nodes[i], new_keys[i])
            self.assertTrue(check_invariant(heap))
        self.assertEqual(sorted(new_keys), heap.extract_all())

    def test_build_tree_by_melding(self):
        heap = pairing_heap()
        i = 0
        while i < datasize:
            # Meld in chunks of 1..4 items (never more than what is left).
            j = max(random.randrange(min(datasize - i, 5)), 1)
            heap2 = pairing_heap(self.data[i:i + j])
            i += j
            heap.meld(heap2)
            self.assertTrue(check_invariant(heap))
        self.assertEqual(self.data_sorted, heap.extract_all())

    def test_cmpfunc(self):
        heap = pairing_heap(self.data, cmpfunc=self._reverse_cmp)
        self.assertEqual(sorted(self.data, reverse=True), heap.extract_all())

    def test_reverse(self):
        heap = pairing_heap(self.data, reverse=True)
        self.assertEqual(sorted(self.data, reverse=True), heap.extract_all())

    def test_cmpfunc_with_reverse(self):
        # A reversed cmpfunc combined with reverse=True restores min-first.
        heap = pairing_heap(self.data, cmpfunc=self._reverse_cmp, reverse=True)
        self.assertEqual(self.data_sorted, heap.extract_all())

    def test_key(self):
        holders = [Holder(item) for item in self.data]
        heap = pairing_heap(holders, key='xyzzy')
        results = heap.extract_all()
        self.assertEqual(self.data_sorted, [item.xyzzy for item in results])

    def test_extract_values_from_nodes(self):
        heap = pairing_heap()
        nodes = [heap.insert(datum) for datum in self.data]
        for node, datum in zip(nodes, self.data):
            self.assertEqual(node.value(), datum)
        self.assertTrue(check_invariant(heap))

    def test_meld_with_incompatible_heap(self):
        heap1 = pairing_heap()
        self.assertRaises(IncompatibleHeaps, heap1.meld,
                          pairing_heap(cmpfunc=self._reverse_cmp))
        self.assertRaises(IncompatibleHeaps, heap1.meld,
                          pairing_heap(reverse=True))
        self.assertRaises(IncompatibleHeaps, heap1.meld,
                          pairing_heap(key='xyzzy'))

    def test_adj_key_wrong_direction(self):
        heap = pairing_heap()
        nodes = [heap.insert(datum) for datum in range(5)]
        self.assertRaises(WrongAdjustKeyDirection, heap.adjust_key,
                          nodes[4], 100)

    def test_adjust_key_on_existing_root(self):
        heap = pairing_heap()
        data = self.data
        nodes = [heap.insert(datum) for datum in data]
        # Indices of every item equal to the minimum, i.e. candidates for the
        # root. Computed directly so the list is always defined, including
        # when the minimum is at index 0 or appears more than once.
        low = min(data)
        low_list = [i for i, datum in enumerate(data) if datum == low]
        expected = data[:]
        for i in low_list:
            new_key = data[i] - random.random()
            heap.adjust_key(nodes[i], new_key)
            expected[i] = new_key
            self.assertTrue(check_invariant(heap))
        self.assertEqual(sorted(expected), heap.extract_all())

    def test_bogus_node_parameter(self):
        heap = pairing_heap()
        for datum in self.data:
            heap.insert(datum)
        self.assertRaises(TypeError, heap.adjust_key, "muhahahah!", 5)
        self.assertTrue(check_invariant(heap))
        self.assertRaises(TypeError, heap.delete, "muhahahah!")

    def test_semibogus_node_parameter(self):
        holders = [Holder(item) for item in self.data]
        heap = pairing_heap(key='xyzzy')
        nodes = [heap.insert(datum) for datum in holders]
        for holder, node in zip(holders, nodes):
            holder.node = node
        # Once extracted, a node no longer belongs to this heap.
        top = heap.extract()
        self.assertRaises(WrongHeap, heap.delete, top.node)
        self.assertTrue(check_invariant(heap))


if __name__ == '__main__':
    import sys
    # Load every test_* method of TestPairingHeap and run them verbosely.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestPairingHeap)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    # Report failures through the exit status so callers/CI notice them.
    sys.exit(0 if result.wasSuccessful() else 1)
