1"""
2Basic statistics module.
3
4This module provides functions for calculating statistics of data, including
5averages, variance, and standard deviation.
6
7Calculating averages
8--------------------
9
10==================  =============================================
11Function            Description
12==================  =============================================
13mean                Arithmetic mean (average) of data.
14harmonic_mean       Harmonic mean of data.
15median              Median (middle value) of data.
16median_low          Low median of data.
17median_high         High median of data.
18median_grouped      Median, or 50th percentile, of grouped data.
19mode                Mode (most common value) of data.
20==================  =============================================
21
22Calculate the arithmetic mean ("the average") of data:
23
24>>> mean([-1.0, 2.5, 3.25, 5.75])
252.625
26
27
28Calculate the standard median of discrete data:
29
30>>> median([2, 3, 4, 5])
313.5
32
33
34Calculate the median, or 50th percentile, of data grouped into class intervals
35centred on the data values provided. E.g. if your data points are rounded to
36the nearest whole number:
37
38>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
392.8333333333...
40
41This should be interpreted in this way: you have two data points in the class
42interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
43the class interval 3.5-4.5. The median of these data points is 2.8333...
44
45
46Calculating variability or spread
47---------------------------------
48
49==================  =============================================
50Function            Description
51==================  =============================================
52pvariance           Population variance of data.
53variance            Sample variance of data.
54pstdev              Population standard deviation of data.
55stdev               Sample standard deviation of data.
56==================  =============================================
57
58Calculate the standard deviation of sample data:
59
60>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
614.38961843444...
62
63If you have previously calculated the mean, you can pass it as the optional
64second argument to the four "spread" functions to avoid recalculating it:
65
66>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
67>>> mu = mean(data)
68>>> pvariance(data, mu)
692.5
70
71
72Exceptions
73----------
74
75A single exception is defined: StatisticsError is a subclass of ValueError.
76
77"""
78
# Public names of this module, i.e. what a star-import exposes.
__all__ = [
    'StatisticsError',
    'pstdev', 'pvariance', 'stdev', 'variance',
    'median', 'median_low', 'median_high', 'median_grouped',
    'mean', 'mode', 'harmonic_mean',
]
84
85import collections
86import math
87import numbers
88
89from fractions import Fraction
90from decimal import Decimal
91from itertools import groupby
92from bisect import bisect_left, bisect_right
93
94
95
96# === Exceptions ===
97
class StatisticsError(ValueError):
    """Error raised by the functions in this module for unusable data."""
100
101
102# === Private utilities ===
103
def _sum(data, start=0):
    """_sum(data [, start]) -> (type, sum, count)

    Add up numeric ``data`` exactly and return a three-item tuple: the
    type the result should be converted to, the total as a Fraction, and
    the number of items taken from ``data``.

    Optional ``start`` (default 0) is added to the total but is not
    counted. For empty ``data`` the total is just ``start`` expressed as
    a Fraction, and the count is 0.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    (<class 'float'>, Fraction(11, 1), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    (<class 'fractions.Fraction'>, Fraction(63, 20), 4)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    count = 0
    start_num, start_den = _exact_ratio(start)
    # Running numerator totals, keyed by denominator. The key None collects
    # values that _exact_ratio could not express as a ratio (NAN or INF).
    by_denominator = {start_den: start_num}
    T = _coerce(int, type(start))
    for typ, run in groupby(data, type):
        T = _coerce(T, typ)  # or raise TypeError
        for num, den in map(_exact_ratio, run):
            count += 1
            by_denominator[den] = by_denominator.get(den, 0) + num
    if None not in by_denominator:
        # Everything was finite: combine the per-denominator subtotals
        # using builtin sum.
        # FIXME is this faster if we sum them in order of the denominator?
        total = sum(
            Fraction(num, den)
            for den, num in sorted(by_denominator.items())
        )
        return (T, total, count)
    # The sum will be a NAN or INF, so the finite subtotals are irrelevant
    # and only the special entry matters.
    total = by_denominator[None]
    assert not _isfinite(total)
    return (T, total, count)
160
161
def _isfinite(x):
    """Return True if x is neither an infinity nor a NAN."""
    try:
        # Objects offering their own is_finite() (likely a Decimal).
        finite = x.is_finite()
    except AttributeError:
        # Everything else goes through math, which coerces to float first.
        finite = math.isfinite(x)
    return finite
167
168
def _coerce(T, S):
    """Return the common type for types T and S, or raise TypeError.

    The exact rules are an implementation detail. See the CoerceTest
    test class in test_statistics for details.
    """
    # See http://bugs.python.org/issue24068.
    assert T is not bool, "initial type T is bool"
    # Identical types are the usual case, so deal with them before
    # anything else.
    if T is S:
        return T
    # A plain int (or bool) mixed with anything yields the other type.
    if S is int or S is bool:
        return T
    if T is int:
        return S
    # If one type is a (strict) subclass of the other, the subclass wins.
    for sub, base in ((S, T), (T, S)):
        if issubclass(sub, base):
            return sub
    # Int subclasses give way to the other type.
    for candidate, other in ((T, S), (S, T)):
        if issubclass(candidate, int):
            return other
    # Fraction mixed with float yields the float (or float subclass).
    for frac_type, float_type in ((T, S), (S, T)):
        if issubclass(frac_type, Fraction) and issubclass(float_type, float):
            return float_type
    # Any other combination is disallowed.
    msg = "don't know how to coerce %s and %s"
    raise TypeError(msg % (T.__name__, S.__name__))
198
199
def _exact_ratio(x):
    """Return Real number x to exact (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float. A float or
    Decimal NAN or INF comes back as ``(x, None)``; anything that cannot
    be expressed as a ratio raises TypeError.
    """
    kind = type(x)
    try:
        # Builtin floats are expected to be the most common input, so
        # they (and exact Decimals) take the shortest route.
        if kind is float or kind is Decimal:
            return x.as_integer_ratio()
        try:
            # x may be an int, Fraction, or Integral ABC.
            ratio = (x.numerator, x.denominator)
        except AttributeError:
            pass
        else:
            return ratio
        try:
            # x may be a float or Decimal subclass.
            return x.as_integer_ratio()
        except AttributeError:
            # Nothing worked; fall through to the TypeError below.
            pass
    except (OverflowError, ValueError):
        # float NAN or INF.
        assert not _isfinite(x)
        return (x, None)
    msg = "can't convert type '{}' to numerator/denominator"
    raise TypeError(msg.format(kind.__name__))
230
231
def _convert(value, T):
    """Convert value to given numeric type T.

    A non-integral value requested as an int type is converted to float
    instead.
    """
    if type(value) is T:
        # Nothing to do. This covers T being Fraction, as well as a NAN
        # or INF value (Decimal or float).
        return value
    wants_int = issubclass(T, int)
    target = float if (wants_int and value.denominator != 1) else T
    try:
        # FIXME: what do we do if this overflows?
        return target(value)
    except TypeError:
        if not issubclass(target, Decimal):
            raise
        # Decimal cannot be built from the value directly, so divide the
        # converted numerator by the converted denominator.
        return target(value.numerator)/target(value.denominator)
248
249
def _counts(data):
    """Return the (value, frequency) pairs that share the top frequency.

    The result is a list, empty when ``data`` yields no items.
    """
    # (value, frequency) pairs, most frequent first.
    ranked = collections.Counter(iter(data)).most_common()
    if ranked:
        # The pairs are ordered by descending frequency, so the ones tied
        # for first place form a prefix of the list.
        top = ranked[0][1]
        ranked = [pair for pair in ranked if pair[1] == top]
    return ranked
262
263
def _find_lteq(a, x):
    """Locate the leftmost value exactly equal to x.

    ``a`` is searched by bisection. The index of the match is returned;
    ValueError is raised when no item equal to x is found.
    """
    pos = bisect_left(a, x)
    if pos == len(a) or not a[pos] == x:
        raise ValueError
    return pos
270
271
def _find_rteq(a, l, x):
    """Locate the rightmost value exactly equal to x.

    ``a`` is searched by bisection, starting the search at index ``l``.
    The index of the match is returned; ValueError is raised when no item
    equal to x is found, including when ``a`` is empty.
    """
    i = bisect_right(a, x, lo=l)
    # bisect_right returns the insertion point to the right of any items
    # equal to x, so a match can only be at i-1. When i is 0 there is
    # nothing to the left; without this guard a[i-1] would wrap around to
    # a[-1] (IndexError on an empty list, or a bogus index of -1).
    if i > 0 and a[i-1] == x:
        return i-1
    raise ValueError
278
279
def _fail_neg(values, errmsg='negative value'):
    """Yield the items of values, raising on the first one below zero.

    This is a generator, so the StatisticsError (carrying ``errmsg``) is
    raised only when iteration reaches the offending item.
    """
    for item in values:
        if not item < 0:
            yield item
            continue
        raise StatisticsError(errmsg)
286
287
288# === Measures of central tendency (averages) ===
289
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    StatisticsError is raised when ``data`` has no items.
    """
    # An iterator has no length, so turn it into a list first.
    values = list(data) if iter(data) is data else data
    size = len(values)
    if size < 1:
        raise StatisticsError('mean requires at least one data point')
    kind, total, seen = _sum(values)
    assert seen == size
    return _convert(total/size, kind)
314
315
def harmonic_mean(data):
    """Return the harmonic mean of data.

    The harmonic mean (also known as the subcontrary mean) is the
    reciprocal of the arithmetic mean of the reciprocals of the data. It
    is often the right average for rates or ratios, such as speeds.
    Example:

    Suppose an investor purchases an equal value of shares in each of
    three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
    What is the average P/E ratio for the investor's portfolio?

    >>> harmonic_mean([2.5, 3, 10])  # For an equal investment portfolio.
    3.6

    Using the arithmetic mean would give an average of about 5.167, which
    is too high.

    ``harmonic_mean`` raises ``StatisticsError`` if ``data`` is empty or
    if any element is less than zero.
    """
    # For a justification for using harmonic mean for P/E ratios, see
    # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
    # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    n = len(data)
    if n < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    if n == 1:
        # A lone data point is its own harmonic mean; it only needs to be
        # validated.
        only = data[0]
        if not isinstance(only, (numbers.Real, Decimal)):
            raise TypeError('unsupported type')
        if only < 0:
            raise StatisticsError(errmsg)
        return only
    try:
        T, total, count = _sum(1/x for x in _fail_neg(data, errmsg))
    except ZeroDivisionError:
        # Taking the reciprocal of a zero data point fails; the result is
        # reported as 0.
        return 0
    assert count == n
    return _convert(n/total, T)
360
361
362# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    With an odd number of data points the middle one is returned. With
    an even number, the result is the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    half, is_odd = divmod(len(ordered), 2)
    if is_odd:
        return ordered[half]
    return (ordered[half - 1] + ordered[half])/2
385
386
def median_low(data):
    """Return the low median of numeric data.

    With an odd number of data points the middle one is returned. With
    an even number, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    # (n - 1)//2 is the middle index for odd n and the lower of the two
    # middle indexes for even n.
    return ordered[(len(ordered) - 1)//2]
407
408
def median_high(data):
    """Return the high median of data.

    With an odd number of data points the middle one is returned. With
    an even number, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    middle = len(ordered)//2
    return ordered[middle]
426
427
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    The median is computed as the 50th percentile, which is appropriate
    for continuous data that has been grouped. In the example above, the
    values 1, 2, 3, etc. stand for the midpoints of the classes 0.5-1.5,
    1.5-2.5, 2.5-3.5, etc. The middle value lies somewhere in class
    3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` is the class interval and defaults to
    1. A different class interval naturally gives a different
    interpolated 50th percentile:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    No check is made that the data points are at least ``interval``
    apart.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    if n == 1:
        return data[0]
    # The value at the midpoint; it stands for the centre of its class
    # interval.
    x = data[n//2]
    for candidate in (x, interval):
        if isinstance(candidate, (str, bytes)):
            raise TypeError('expected number but got %r' % candidate)
    try:
        lower_limit = x - interval/2  # Lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        lower_limit = float(x) - float(interval)/2

    # Both lookups use bisection on the sorted data.
    # Index of the leftmost occurrence of x ...
    leftmost = _find_lteq(data, x)
    # ... and of the rightmost one, searching from leftmost onwards
    # (leftmost <= rightmost is assumed).
    rightmost = _find_rteq(data, leftmost, x)
    below = leftmost                    # Number of items sorted before x.
    frequency = rightmost - leftmost + 1  # Number of items equal to x.
    return lower_limit + interval*(n/2 - below)/frequency
481
482
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` treats the data as discrete and returns a single value, which
    is the standard treatment of the mode as commonly taught in schools:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    Nominal (non-numeric) data works too:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    StatisticsError is raised unless there is exactly one most common
    value; that includes empty data.
    """
    # The (value, frequency) pairs tied for the highest frequency.
    winners = _counts(data)
    if not winners:
        raise StatisticsError('no mode for empty data')
    if len(winners) != 1:
        raise StatisticsError(
                'no unique mode; found %d equally common values' % len(winners)
                )
    return winners[0][0]
510
511
512# === Measures of spread ===
513
514# See http://mathworld.wolfram.com/Variance.html
515#     http://mathworld.wolfram.com/SampleVariance.html
516#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
517#
518# Under no circumstances use the so-called "computational formula for
519# variance", as that is only suitable for hand calculations with a small
520# amount of low-precision data. It has terrible numeric properties.
521#
522# See a comparison of three computational methods here:
523# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
524
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    The result is a ``(type, total)`` pair. When ``c`` is None the mean
    is computed in a first pass and the deviations from it in a second
    one. Otherwise the deviations are taken from ``c`` as given; use that
    form with care, as it can lead to garbage results.
    """
    centre = mean(data) if c is None else c
    T, squares, n_squares = _sum((x-centre)**2 for x in data)
    # Mathematically the plain deviations add up to zero, but rounding
    # error may leave a residue; it is used to correct the total below.
    U, residue, n_plain = _sum((x-centre) for x in data)
    assert T == U and n_squares == n_plain
    total = squares - residue**2/len(data)
    assert not total < 0, 'negative sum of square deviations: %f' % total
    return (T, total)
543
544
def variance(data, xbar=None):
    """Return the sample variance of data.

    ``data`` is an iterable of Real-valued numbers holding at least two
    values. Optional ``xbar``, when given, should be the mean of the
    data; when it is omitted or None the mean is calculated for you.

    This is the function to use when the data is a sample from a
    population. For the variance of an entire population, see
    ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    A mean that has already been calculated can be passed as the optional
    second argument ``xbar`` so it is not calculated again:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    No check is made that ``xbar`` really is the mean of ``data``.
    Arbitrary values for ``xbar`` may lead to invalid or impossible
    results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    StatisticsError is raised for fewer than two data points.
    """
    # An iterator has no length, so turn it into a list first.
    values = list(data) if iter(data) is data else data
    size = len(values)
    if size < 2:
        raise StatisticsError('variance requires at least two data points')
    kind, squares = _ss(values, xbar)
    return _convert(squares/(size - 1), kind)
590
591
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    ``data`` is an iterable of Real-valued numbers holding at least one
    value. Optional ``mu``, when given, should be the mean of the data;
    when it is omitted or None the mean is calculated for you.

    This is the function to use for the variance of an entire population.
    To estimate the variance from a sample, ``variance`` is usually the
    better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    A mean that has already been calculated can be passed as the optional
    second argument so it is not calculated again:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    No check is made that ``mu`` really is the mean of ``data``.
    Arbitrary values for ``mu`` may lead to invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    StatisticsError is raised when ``data`` is empty.
    """
    # An iterator has no length, so turn it into a list first.
    values = list(data) if iter(data) is data else data
    size = len(values)
    if size < 1:
        raise StatisticsError('pvariance requires at least one data point')
    kind, squares = _ss(values, mu)
    return _convert(squares/size, kind)
638
639
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    The arguments are handed to ``variance``; see there for details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    spread = variance(data, xbar)
    try:
        # Use the value's own sqrt() method when it has one.
        root = spread.sqrt()
    except AttributeError:
        root = math.sqrt(spread)
    return root
654
655
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    The arguments are handed to ``pvariance``; see there for details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    spread = pvariance(data, mu)
    try:
        # Use the value's own sqrt() method when it has one.
        root = spread.sqrt()
    except AttributeError:
        root = math.sqrt(spread)
    return root
670