from __future__ import division
import random

def describe(data, sample=True):
    """Compute the size, mean, and variance of the dataset in one pass.

    data may be any iterable of numbers (it is consumed exactly once, so
    iterators and generators are fine).

    If sample is True, the variance is computed with an n-1 divisor to
    reflect the degree of freedom lost in sampled data.  If sample is False,
    the divisor is n, giving the variance for a complete population.

    Returns a 3-tuple (count, mean, variance) in both modes.  The variance
    is None when sample is True and there is only one element, because the
    n-1 divisor would be zero.

    Raises ValueError if data is empty.
    """

    # Formula for recurrence taken from Seminumerical Algorithms, Knuth, 4.2.2
    # It has substantially better numerical performance than computing the
    # result from sum(data) and sum(x*x for x in data).
    it = iter(data)
    try:
        # The next() builtin works on Python 2.6+ and Python 3; the
        # it.next() method exists only on Python 2 iterators.
        m = next(it)                    # running mean
    except StopIteration:
        raise ValueError('data must have at least one element')
    s = 0                               # running sum((x-mean)**2 for x in data)
    k = 1                               # number of items
    dm = 0                              # carried forward error term
    for x in it:
        k += 1

        # This block computes:  newm = m + (x-m)/k
        # The use of a cumulative error term improves accuracy when the mean is
        # much larger than the standard deviation.  Also, it makes the formula
        # less sensitive to data order (with sorted data resulting in larger
        # relative errors according to experiments by Chris Reedy).
        adjm = (x-m)/k - dm             # relies on true division
        newm = m + adjm
        # Whatever part of adjm was lost to rounding in the addition above is
        # remembered here and subtracted from the next adjustment.
        dm = (newm - m) - adjm

        s += (x-m)*(x-newm)
        m = newm

    if not sample:
        return (k, m, s / k)            # population variance
    if k < 2:
        # A single observation has no sample variance (divisor would be 0).
        return (k, m, None)
    return (k, m, s / (k-1))            # sample variance

def select(data, n):
    """Return the element of rank n in data (rank 0 is the smallest).

    The result is the same as sorted(data)[n].  Instead of sorting, the
    candidates are repeatedly partitioned around a randomly chosen pivot and
    only the partition containing the requested rank is kept, so the
    expected running time is linear in the number of elements.

    data is copied into a list first, so the caller's object is not
    modified and any iterable is accepted.

    Raises ValueError if n is not a valid index into the data.
    """
    candidates = list(data)
    if not 0 <= n < len(candidates):
        raise ValueError('not enough elements for the given rank')

    rank = n
    while True:
        pivot = random.choice(candidates)

        # Three-way partition; elements that compare neither below nor above
        # the pivot are only counted, since any of them equals the answer.
        smaller = []
        larger = []
        ties = 0
        for item in candidates:
            if item < pivot:
                smaller.append(item)
            elif pivot < item:
                larger.append(item)
            else:
                ties += 1

        below = len(smaller)
        if rank < below:
            # The target lies among the elements under the pivot.
            candidates = smaller
            continue
        if rank < below + ties:
            # The target falls inside the run of pivot-equal elements.
            return pivot
        # The target lies above the pivot; re-base the rank to that partition.
        candidates = larger
        rank -= below + ties

def median(data):
    """Return a value with half of the data elements above it and half
    below it.

    For an odd number of elements this is the middle element in rank order;
    for an even number it is the average of the two middle elements.

    Raises ValueError if data is empty.
    """
    try:
        size = len(data)
    except TypeError:
        # No len() available (e.g. an iterator): materialise the data once so
        # it can be measured and then handed to select() more than once.
        data = tuple(data)
        size = len(data)
    if not size:
        raise ValueError('data must have at least one element')

    half = size // 2
    upper_mid = select(data, half)
    if size % 2:
        # Odd length: there is a single middle element.
        return upper_mid
    # Even length: average the two middle ranks.  The division relies on the
    # module-level "from __future__ import division" for true division.
    lower_mid = select(data, half - 1)
    return (upper_mid + lower_mid) / 2

if __name__ == '__main__':
    # Minimal self-test: every line prints True when its check passes.
    #
    # The function under test is describe(); the name stats() is not defined
    # in this module.  print(...) with a single parenthesised argument gives
    # the same output under the Python 2 print statement and the Python 3
    # print function.
    print(describe([3, 4, 5]) == (3, 4.0, 1.0))
    print(describe([2, 4, 6]) == (3, 4.0, 4.0))
    print(describe([3, 5, 7]) == (3, 5.0, 4.0))

    # The same checks with Decimal inputs.  map() may yield a one-shot
    # iterator; both describe() and median() accept that.
    from decimal import Decimal
    print(describe(map(Decimal, [3, 5, 7])) == (3, Decimal("5"), Decimal("4")))
    print(median(map(Decimal, [3, 5, 7])) == Decimal('5'))
    print(median(map(Decimal, [3, 5, 7, 9])) == Decimal('6'))
    print(median(map(Decimal, [3, 6, 7, 9])) == Decimal('6.5'))
