from __future__ import division
from statistics import describe, select, median


### UNITTESTS #################################################

import unittest
import random

def g(n):
    """Yield the integers 0 .. n-1 from a plain generator.

    Used by the tests as an input that can only be iterated: a generator
    object has no __len__ and cannot be indexed.

    ``range`` is used instead of ``xrange`` so the helper runs on both
    Python 2 and Python 3 (``xrange`` does not exist on Python 3).
    """
    for i in range(n):
        yield i

class TestStats(unittest.TestCase):
    """Tests for the describe(), median() and select() functions.

    The test data is built with ``list(range(...))`` and the loops use
    ``range`` so the suite runs unchanged on Python 2 and Python 3
    (on Python 3 ``range`` is not a list and ``xrange`` does not exist).
    """

    def test_mean(self):
        """The mean is read from index 1 of describe()'s result."""
        def mean(data):
            return describe(data)[1]
        # A list and a length-less generator must give the same answer.
        self.assertEqual(mean(list(range(6))), 15/6.0)
        self.assertEqual(mean(g(6)), 15/6.0)
        self.assertEqual(mean([10]), 10)
        # Empty input is rejected; so is non-numeric data.
        self.assertRaises(ValueError, mean, [])
        self.assertRaises(TypeError, mean, 'abc')

    def test_stddev(self):
        """The standard deviation is derived from index 2 of describe()."""
        def stddev(data, *args):
            # Extra positional arguments are forwarded to describe().
            var = describe(data, *args)[2]
            if not var:
                # None (no variance available) and zero pass through as-is.
                return var
            return var ** 0.5
        self.assertEqual(stddev([10,15,20]), 5)
        # Default call: the expected value matches the n-1 divisor ...
        self.assertEqual(round(stddev([11.1,4,9,13]), 3), 3.878)
        # ... while passing False matches the n divisor.
        self.assertEqual(round(stddev([11.1,4,9,13], False), 3), 3.358)
        self.assertEqual(stddev([10], False), 0.0)
        # A single data point has no variance in the default mode.
        self.assertEqual(stddev([1]), None)
        self.assertRaises(ValueError, stddev, [], False)

    def test_median(self):
        """median() handles even/odd sizes and leaves its input untouched."""
        # list(...) is required: random.shuffle needs a mutable sequence.
        data = list(range(10))
        random.shuffle(data)
        copy = data[:]
        self.assertEqual(median(data), 4.5)
        # The caller's list must not have been reordered.
        self.assertEqual(data, copy)
        self.assertEqual(median(g(10)), 4.5)
        # Now 0..10: an odd number of items, so the median is the middle one.
        data.insert(1, 10)
        self.assertEqual(median(data), 5)
        self.assertEqual(median([-50,10,2,11]), 6)
        self.assertEqual(median([30,10,2]), 10)
        self.assertRaises(ValueError, median, [])

    def test_select(self):
        """select(data, n) agrees with sorted(data)[n] for random ranks."""
        # Many duplicates (20 and 85) are included on purpose.
        testdata = list(range(2000)) + [20] * 500 + [85] * 500
        sorteddata = sorted(testdata)
        n = len(testdata)

        shuffled = testdata[:]
        random.shuffle(shuffled)
        ascending = sorteddata[:]
        descending = sorteddata[:]
        descending.reverse()

        # Same check against shuffled, sorted and reverse sorted input.
        for a in (shuffled, ascending, descending):
            for _ in range(100):
                nth = random.randrange(n)
                self.assertEqual(select(a, nth), sorteddata[nth])

        # A rank past the end of the data is rejected.
        self.assertRaises(ValueError, select, a, n+1)

if __name__ == '__main__':
    # Run the TestStats cases verbosely when the file is executed directly.
    # The suite is built through a TestLoader because unittest.makeSuite()
    # is deprecated and was removed in Python 3.13.
    suite = unittest.TestSuite()
    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestStats))
    unittest.TextTestRunner(verbosity=2).run(suite)
