1 """
2 Basic statistics module.
3
4 This module provides functions for calculating statistics of data, including
5 averages, variance, and standard deviation.
6
7 Calculating averages
8 --------------------
9
==================  ==================================================
Function            Description
==================  ==================================================
mean                Arithmetic mean (average) of data.
fmean               Fast, floating point arithmetic mean.
geometric_mean      Geometric mean of data.
harmonic_mean       Harmonic mean of data.
median              Median (middle value) of data.
median_low          Low median of data.
median_high         High median of data.
median_grouped      Median, or 50th percentile, of grouped data.
mode                Mode (most common value) of data.
multimode           List of modes (most common values of data).
quantiles           Divide data into intervals with equal probability.
==================  ==================================================
25
26 Calculate the arithmetic mean ("the average") of data:
27
28 >>> mean([-1.0, 2.5, 3.25, 5.75])
29 2.625
30
31
32 Calculate the standard median of discrete data:
33
34 >>> median([2, 3, 4, 5])
35 3.5
36
37
38 Calculate the median, or 50th percentile, of data grouped into class intervals
39 centred on the data values provided. E.g. if your data points are rounded to
40 the nearest whole number:
41
42 >>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
43 2.8333333333...
44
45 This should be interpreted in this way: you have two data points in the class
46 interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
47 the class interval 3.5-4.5. The median of these data points is 2.8333...
48
49
50 Calculating variability or spread
51 ---------------------------------
52
==================  =============================================
Function            Description
==================  =============================================
pvariance           Population variance of data.
variance            Sample variance of data.
pstdev              Population standard deviation of data.
stdev               Sample standard deviation of data.
==================  =============================================
61
62 Calculate the standard deviation of sample data:
63
64 >>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
65 4.38961843444...
66
67 If you have previously calculated the mean, you can pass it as the optional
68 second argument to the four "spread" functions to avoid recalculating it:
69
70 >>> data = [1, 2, 2, 4, 4, 4, 5, 6]
71 >>> mu = mean(data)
72 >>> pvariance(data, mu)
73 2.5
74
75
76 Statistics for relations between two inputs
77 -------------------------------------------
78
==================  ====================================================
Function            Description
==================  ====================================================
covariance          Sample covariance for two variables.
correlation         Pearson's correlation coefficient for two variables.
linear_regression   Intercept and slope for simple linear regression.
==================  ====================================================
86
87 Calculate covariance, Pearson's correlation, and simple linear regression
88 for two inputs:
89
90 >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
91 >>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
92 >>> covariance(x, y)
93 0.75
94 >>> correlation(x, y) #doctest: +ELLIPSIS
95 0.31622776601...
>>> linear_regression(x, y)
97 LinearRegression(slope=0.1, intercept=1.5)
98
99
100 Exceptions
101 ----------
102
103 A single exception is defined: StatisticsError is a subclass of ValueError.
104
105 """
106
107 __all__ = [
108 'NormalDist',
109 'StatisticsError',
110 'correlation',
111 'covariance',
112 'fmean',
113 'geometric_mean',
114 'harmonic_mean',
115 'linear_regression',
116 'mean',
117 'median',
118 'median_grouped',
119 'median_high',
120 'median_low',
121 'mode',
122 'multimode',
123 'pstdev',
124 'pvariance',
125 'quantiles',
126 'stdev',
127 'variance',
128 ]
129
130 import math
131 import numbers
132 import random
133 import sys
134
135 from fractions import Fraction
136 from decimal import Decimal
137 from itertools import count, groupby, repeat
138 from bisect import bisect_left, bisect_right
139 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
140 from functools import reduce
141 from operator import itemgetter
142 from collections import Counter, namedtuple, defaultdict
143
144 _SQRT2 = sqrt(2.0)
145
146 # === Exceptions ===
147
class StatisticsError(ValueError):
149 pass
150
151
152 # === Private utilities ===
153
154 def _sum(data):
155 """_sum(data) -> (type, sum, count)
156
157 Return a high-precision sum of the given numeric data as a fraction,
158 together with the type to be converted to and the count of items.
159
160 Examples
161 --------
162
163 >>> _sum([3, 2.25, 4.5, -0.5, 0.25])
164 (<class 'float'>, Fraction(19, 2), 5)
165
166 Some sources of round-off error will be avoided:
167
168 # Built-in sum returns zero.
169 >>> _sum([1e50, 1, -1e50] * 1000)
170 (<class 'float'>, Fraction(1000, 1), 3000)
171
172 Fractions and Decimals are also supported:
173
174 >>> from fractions import Fraction as F
175 >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
176 (<class 'fractions.Fraction'>, Fraction(63, 20), 4)
177
178 >>> from decimal import Decimal as D
179 >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
180 >>> _sum(data)
181 (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)
182
183 Mixed types are currently treated as an error, except that int is
184 allowed.
185 """
186 count = 0
187 types = set()
188 types_add = types.add
189 partials = {}
190 partials_get = partials.get
191 for typ, values in groupby(data, type):
192 types_add(typ)
193 for n, d in map(_exact_ratio, values):
194 count += 1
195 partials[d] = partials_get(d, 0) + n
196 if None in partials:
197 # The sum will be a NAN or INF. We can ignore all the finite
198 # partials, and just look at this special one.
199 total = partials[None]
200 assert not _isfinite(total)
201 else:
202 # Sum all the partial sums using builtin sum.
203 total = sum(Fraction(n, d) for d, n in partials.items())
204 T = reduce(_coerce, types, int) # or raise TypeError
205 return (T, total, count)
206
207
208 def _ss(data, c=None):
209 """Return the exact mean and sum of square deviations of sequence data.
210
211 Calculations are done in a single pass, allowing the input to be an iterator.
212
    If given, *c* is used as the mean; otherwise, the mean is calculated
    from the data. Use the *c* argument with care, as it can lead to
    garbage results.
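
    A small illustration (worked by hand for this sketch):

    >>> _ss([1, 2, 3])
    (<class 'int'>, Fraction(2, 1), Fraction(2, 1), 3)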
215
216 """
217 if c is not None:
218 T, ssd, count = _sum((d := x - c) * d for x in data)
219 return (T, ssd, c, count)
220 count = 0
221 types = set()
222 types_add = types.add
223 sx_partials = defaultdict(int)
224 sxx_partials = defaultdict(int)
225 for typ, values in groupby(data, type):
226 types_add(typ)
227 for n, d in map(_exact_ratio, values):
228 count += 1
229 sx_partials[d] += n
230 sxx_partials[d] += n * n
231 if not count:
232 ssd = c = Fraction(0)
233 elif None in sx_partials:
234 # The sum will be a NAN or INF. We can ignore all the finite
235 # partials, and just look at this special one.
236 ssd = c = sx_partials[None]
237 assert not _isfinite(ssd)
238 else:
239 sx = sum(Fraction(n, d) for d, n in sx_partials.items())
240 sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
241 # This formula has poor numeric properties for floats,
242 # but with fractions it is exact.
243 ssd = (count * sxx - sx * sx) / count
244 c = sx / count
245 T = reduce(_coerce, types, int) # or raise TypeError
246 return (T, ssd, c, count)
247
248
249 def _isfinite(x):
250 try:
251 return x.is_finite() # Likely a Decimal.
252 except AttributeError:
253 return math.isfinite(x) # Coerces to float first.
254
255
256 def _coerce(T, S):
257 """Coerce types T and S to a common type, or raise TypeError.
258
259 Coercion rules are currently an implementation detail. See the CoerceTest
260 test class in test_statistics for details.
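
    Illustrative cases (following the rules above; the ints here are
    non-bool ints):

    >>> _coerce(int, float)
    <class 'float'>
    >>> _coerce(Fraction, int)
    <class 'fractions.Fraction'>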
261 """
262 # See http://bugs.python.org/issue24068.
263 assert T is not bool, "initial type T is bool"
264 # If the types are the same, no need to coerce anything. Put this
265 # first, so that the usual case (no coercion needed) happens as soon
266 # as possible.
267 if T is S: return T
268 # Mixed int & other coerce to the other type.
269 if S is int or S is bool: return T
270 if T is int: return S
271 # If one is a (strict) subclass of the other, coerce to the subclass.
272 if issubclass(S, T): return S
273 if issubclass(T, S): return T
274 # Ints coerce to the other type.
275 if issubclass(T, int): return S
276 if issubclass(S, int): return T
277 # Mixed fraction & float coerces to float (or float subclass).
278 if issubclass(T, Fraction) and issubclass(S, float):
279 return S
280 if issubclass(T, float) and issubclass(S, Fraction):
281 return T
282 # Any other combination is disallowed.
283 msg = "don't know how to coerce %s and %s"
284 raise TypeError(msg % (T.__name__, S.__name__))
285
286
287 def _exact_ratio(x):
288 """Return Real number x to exact (numerator, denominator) pair.
289
290 >>> _exact_ratio(0.25)
291 (1, 4)
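
    A Decimal gives the same exact ratio (illustrative):

    >>> _exact_ratio(Decimal('0.25'))
    (1, 4)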
292
293 x is expected to be an int, Fraction, Decimal or float.
294 """
295
296 # XXX We should revisit whether using fractions to accumulate exact
297 # ratios is the right way to go.
298
299 # The integer ratios for binary floats can have numerators or
300 # denominators with over 300 decimal digits. The problem is more
301 # acute with decimal floats where the default decimal context
302 # supports a huge range of exponents from Emin=-999999 to
303 # Emax=999999. When expanded with as_integer_ratio(), numbers like
304 # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
305 # numerators or denominators that will slow computation.
306
307 # When the integer ratios are accumulated as fractions, the size
308 # grows to cover the full range from the smallest magnitude to the
309 # largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
310 # has a 616 digit numerator. Likewise,
311 # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
312 # has 10,003 digit numerator.
313
    # This doesn't seem to have been a problem in practice, but it is a
315 # potential pitfall.
316
317 try:
318 return x.as_integer_ratio()
319 except AttributeError:
320 pass
321 except (OverflowError, ValueError):
322 # float NAN or INF.
323 assert not _isfinite(x)
324 return (x, None)
325 try:
326 # x may be an Integral ABC.
327 return (x.numerator, x.denominator)
328 except AttributeError:
329 msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
330 raise TypeError(msg)
331
332
333 def _convert(value, T):
334 """Convert value to given numeric type T."""
335 if type(value) is T:
336 # This covers the cases where T is Fraction, or where value is
337 # a NAN or INF (Decimal or float).
338 return value
339 if issubclass(T, int) and value.denominator != 1:
340 T = float
341 try:
342 # FIXME: what do we do if this overflows?
343 return T(value)
344 except TypeError:
345 if issubclass(T, Decimal):
346 return T(value.numerator) / T(value.denominator)
347 else:
348 raise
349
350
351 def _fail_neg(values, errmsg='negative value'):
352 """Iterate over values, failing if any are less than zero."""
353 for x in values:
354 if x < 0:
355 raise StatisticsError(errmsg)
356 yield x
357
358
359 def _rank(data, /, *, key=None, reverse=False, ties='average', start=1) -> list[float]:
360 """Rank order a dataset. The lowest value has rank 1.
361
362 Ties are averaged so that equal values receive the same rank:
363
364 >>> data = [31, 56, 31, 25, 75, 18]
365 >>> _rank(data)
366 [3.5, 5.0, 3.5, 2.0, 6.0, 1.0]
367
368 The operation is idempotent:
369
370 >>> _rank([3.5, 5.0, 3.5, 2.0, 6.0, 1.0])
371 [3.5, 5.0, 3.5, 2.0, 6.0, 1.0]
372
373 It is possible to rank the data in reverse order so that the
374 highest value has rank 1. Also, a key-function can extract
375 the field to be ranked:
376
377 >>> goals = [('eagles', 45), ('bears', 48), ('lions', 44)]
378 >>> _rank(goals, key=itemgetter(1), reverse=True)
379 [2.0, 1.0, 3.0]
380
381 Ranks are conventionally numbered starting from one; however,
382 setting *start* to zero allows the ranks to be used as array indices:
383
384 >>> prize = ['Gold', 'Silver', 'Bronze', 'Certificate']
385 >>> scores = [8.1, 7.3, 9.4, 8.3]
386 >>> [prize[int(i)] for i in _rank(scores, start=0, reverse=True)]
387 ['Bronze', 'Certificate', 'Gold', 'Silver']
388
389 """
390 # If this function becomes public at some point, more thought
391 # needs to be given to the signature. A list of ints is
392 # plausible when ties is "min" or "max". When ties is "average",
393 # either list[float] or list[Fraction] is plausible.
394
395 # Default handling of ties matches scipy.stats.mstats.spearmanr.
396 if ties != 'average':
397 raise ValueError(f'Unknown tie resolution method: {ties!r}')
398 if key is not None:
399 data = map(key, data)
400 val_pos = sorted(zip(data, count()), reverse=reverse)
401 i = start - 1
402 result = [0] * len(val_pos)
403 for _, g in groupby(val_pos, key=itemgetter(0)):
404 group = list(g)
405 size = len(group)
406 rank = i + (size + 1) / 2
407 for value, orig_pos in group:
408 result[orig_pos] = rank
409 i += size
410 return result
411
412
413 def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
414 """Square root of n/m, rounded to the nearest integer using round-to-odd."""
415 # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
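    # Illustrative trace (not from the original source): for n=25, m=4,
    # isqrt(25 // 4) = 2 and 2*2*4 != 25, so the low bit is forced on,
    # giving the round-to-odd result 3 (the true value sqrt(25/4) is 2.5).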
416 a = math.isqrt(n // m)
417 return a | (a*a*m != n)
418
419
420 # For 53 bit precision floats, the bit width used in
421 # _float_sqrt_of_frac() is 109.
422 _sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
423
424
425 def _float_sqrt_of_frac(n: int, m: int) -> float:
426 """Square root of n/m as a float, correctly rounded."""
427 # See principle and proof sketch at: https://bugs.python.org/msg407078
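    # For instance (illustrative), _float_sqrt_of_frac(1, 2) returns
    # 0.7071067811865476, the float nearest to sqrt(0.5).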
428 q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
429 if q >= 0:
430 numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
431 denominator = 1
432 else:
433 numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
434 denominator = 1 << -q
435 return numerator / denominator # Convert to float
436
437
438 def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
439 """Square root of n/m as a Decimal, correctly rounded."""
440 # Premise: For decimal, computing (n/m).sqrt() can be off
441 # by 1 ulp from the correctly rounded result.
442 # Method: Check the result, moving up or down a step if needed.
443 if n <= 0:
444 if not n:
445 return Decimal('0.0')
446 n, m = -n, -m
447
448 root = (Decimal(n) / Decimal(m)).sqrt()
449 nr, dr = root.as_integer_ratio()
450
451 plus = root.next_plus()
452 np, dp = plus.as_integer_ratio()
453 # test: n / m > ((root + plus) / 2) ** 2
454 if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
455 return plus
456
457 minus = root.next_minus()
458 nm, dm = minus.as_integer_ratio()
459 # test: n / m < ((root + minus) / 2) ** 2
460 if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
461 return minus
462
463 return root
464
465
466 # === Measures of central tendency (averages) ===
467
468 def mean(data):
469 """Return the sample arithmetic mean of data.
470
471 >>> mean([1, 2, 3, 4, 4])
472 2.8
473
474 >>> from fractions import Fraction as F
475 >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
476 Fraction(13, 21)
477
478 >>> from decimal import Decimal as D
479 >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
480 Decimal('0.5625')
481
482 If ``data`` is empty, StatisticsError will be raised.
483 """
484 T, total, n = _sum(data)
485 if n < 1:
486 raise StatisticsError('mean requires at least one data point')
487 return _convert(total / n, T)
488
489
490 def fmean(data, weights=None):
491 """Convert data to floats and compute the arithmetic mean.
492
493 This runs faster than the mean() function and it always returns a float.
494 If the input dataset is empty, it raises a StatisticsError.
495
496 >>> fmean([3.5, 4.0, 5.25])
497 4.25
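
    Weights are supported via the optional *weights* argument; an
    illustrative example (values invented for this sketch):

    >>> fmean([3.5, 4.0, 5.25], weights=[0.25, 0.25, 0.5])
    4.5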
498 """
499 if weights is None:
500 try:
501 n = len(data)
502 except TypeError:
503 # Handle iterators that do not define __len__().
504 n = 0
505 def count(iterable):
506 nonlocal n
507 for n, x in enumerate(iterable, start=1):
508 yield x
509 data = count(data)
510 total = fsum(data)
511 if not n:
512 raise StatisticsError('fmean requires at least one data point')
513 return total / n
514 if not isinstance(weights, (list, tuple)):
515 weights = list(weights)
516 try:
517 num = sumprod(data, weights)
518 except ValueError:
519 raise StatisticsError('data and weights must be the same length')
520 den = fsum(weights)
521 if not den:
522 raise StatisticsError('sum of weights must be non-zero')
523 return num / den
524
525
526 def geometric_mean(data):
527 """Convert data to floats and compute the geometric mean.
528
529 Raises a StatisticsError if the input dataset is empty,
530 if it contains a zero, or if it contains a negative value.
531
532 No special efforts are made to achieve exact results.
533 (However, this may change in the future.)
534
535 >>> round(geometric_mean([54, 24, 36]), 9)
536 36.0
537 """
538 try:
539 return exp(fmean(map(log, data)))
540 except ValueError:
541 raise StatisticsError('geometric mean requires a non-empty dataset '
542 'containing positive numbers') from None
543
544
545 def harmonic_mean(data, weights=None):
546 """Return the harmonic mean of data.
547
548 The harmonic mean is the reciprocal of the arithmetic mean of the
549 reciprocals of the data. It can be used for averaging ratios or
550 rates, for example speeds.
551
    Suppose a car travels 40 km/hr for 5 km and then speeds up to
553 60 km/hr for another 5 km. What is the average speed?
554
555 >>> harmonic_mean([40, 60])
556 48.0
557
558 Suppose a car travels 40 km/hr for 5 km, and when traffic clears,
    speeds up to 60 km/hr for the remaining 30 km of the journey. What
560 is the average speed?
561
562 >>> harmonic_mean([40, 60], weights=[5, 30])
563 56.0
564
565 If ``data`` is empty, or any element is less than zero,
566 ``harmonic_mean`` will raise ``StatisticsError``.
567 """
568 if iter(data) is data:
569 data = list(data)
570 errmsg = 'harmonic mean does not support negative values'
571 n = len(data)
572 if n < 1:
573 raise StatisticsError('harmonic_mean requires at least one data point')
574 elif n == 1 and weights is None:
575 x = data[0]
576 if isinstance(x, (numbers.Real, Decimal)):
577 if x < 0:
578 raise StatisticsError(errmsg)
579 return x
580 else:
581 raise TypeError('unsupported type')
582 if weights is None:
583 weights = repeat(1, n)
584 sum_weights = n
585 else:
586 if iter(weights) is weights:
587 weights = list(weights)
588 if len(weights) != n:
589 raise StatisticsError('Number of weights does not match data size')
590 _, sum_weights, _ = _sum(w for w in _fail_neg(weights, errmsg))
591 try:
592 data = _fail_neg(data, errmsg)
593 T, total, count = _sum(w / x if w else 0 for w, x in zip(weights, data))
594 except ZeroDivisionError:
595 return 0
596 if total <= 0:
597 raise StatisticsError('Weighted sum must be positive')
598 return _convert(sum_weights / total, T)
599
600 # FIXME: investigate ways to calculate medians without sorting? Quickselect?
601 def median(data):
602 """Return the median (middle value) of numeric data.
603
604 When the number of data points is odd, return the middle data point.
605 When the number of data points is even, the median is interpolated by
606 taking the average of the two middle values:
607
608 >>> median([1, 3, 5])
609 3
610 >>> median([1, 3, 5, 7])
611 4.0
612
613 """
614 data = sorted(data)
615 n = len(data)
616 if n == 0:
617 raise StatisticsError("no median for empty data")
618 if n % 2 == 1:
619 return data[n // 2]
620 else:
621 i = n // 2
622 return (data[i - 1] + data[i]) / 2
623
624
625 def median_low(data):
626 """Return the low median of numeric data.
627
628 When the number of data points is odd, the middle value is returned.
629 When it is even, the smaller of the two middle values is returned.
630
631 >>> median_low([1, 3, 5])
632 3
633 >>> median_low([1, 3, 5, 7])
634 3
635
636 """
637 data = sorted(data)
638 n = len(data)
639 if n == 0:
640 raise StatisticsError("no median for empty data")
641 if n % 2 == 1:
642 return data[n // 2]
643 else:
644 return data[n // 2 - 1]
645
646
647 def median_high(data):
648 """Return the high median of data.
649
650 When the number of data points is odd, the middle value is returned.
651 When it is even, the larger of the two middle values is returned.
652
653 >>> median_high([1, 3, 5])
654 3
655 >>> median_high([1, 3, 5, 7])
656 5
657
658 """
659 data = sorted(data)
660 n = len(data)
661 if n == 0:
662 raise StatisticsError("no median for empty data")
663 return data[n // 2]
664
665
666 def median_grouped(data, interval=1.0):
667 """Estimates the median for numeric data binned around the midpoints
668 of consecutive, fixed-width intervals.
669
670 The *data* can be any iterable of numeric data with each value being
671 exactly the midpoint of a bin. At least one value must be present.
672
    The *interval* is the width of each bin.
674
675 For example, demographic information may have been summarized into
676 consecutive ten-year age groups with each group being represented
677 by the 5-year midpoints of the intervals:
678
679 >>> demographics = Counter({
680 ... 25: 172, # 20 to 30 years old
681 ... 35: 484, # 30 to 40 years old
682 ... 45: 387, # 40 to 50 years old
683 ... 55: 22, # 50 to 60 years old
684 ... 65: 6, # 60 to 70 years old
685 ... })
686
687 The 50th percentile (median) is the 536th person out of the 1071
688 member cohort. That person is in the 30 to 40 year old age group.
689
690 The regular median() function would assume that everyone in the
691 tricenarian age group was exactly 35 years old. A more tenable
692 assumption is that the 484 members of that age group are evenly
693 distributed between 30 and 40. For that, we use median_grouped().
694
695 >>> data = list(demographics.elements())
696 >>> median(data)
697 35
698 >>> round(median_grouped(data, interval=10), 1)
699 37.5
700
701 The caller is responsible for making sure the data points are separated
702 by exact multiples of *interval*. This is essential for getting a
703 correct result. The function does not check this precondition.
704
705 Inputs may be any numeric type that can be coerced to a float during
706 the interpolation step.
707
708 """
709 data = sorted(data)
710 n = len(data)
711 if not n:
712 raise StatisticsError("no median for empty data")
713
714 # Find the value at the midpoint. Remember this corresponds to the
715 # midpoint of the class interval.
716 x = data[n // 2]
717
718 # Using O(log n) bisection, find where all the x values occur in the data.
719 # All x will lie within data[i:j].
720 i = bisect_left(data, x)
721 j = bisect_right(data, x, lo=i)
722
723 # Coerce to floats, raising a TypeError if not possible
724 try:
725 interval = float(interval)
726 x = float(x)
727 except ValueError:
        raise TypeError('Value cannot be converted to a float')
729
730 # Interpolate the median using the formula found at:
731 # https://www.cuemath.com/data/median-of-grouped-data/
732 L = x - interval / 2.0 # Lower limit of the median interval
733 cf = i # Cumulative frequency of the preceding interval
    f = j - i               # Number of elements in the median interval
735 return L + interval * (n / 2 - cf) / f
736
737
738 def mode(data):
739 """Return the most common data point from discrete or nominal data.
740
741 ``mode`` assumes discrete data, and returns a single value. This is the
742 standard treatment of the mode as commonly taught in schools:
743
744 >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
745 3
746
747 This also works with nominal (non-numeric) data:
748
749 >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
750 'red'
751
    If there are multiple modes with the same frequency, return the first one
753 encountered:
754
755 >>> mode(['red', 'red', 'green', 'blue', 'blue'])
756 'red'
757
    If *data* is empty, ``mode`` raises StatisticsError.
759
760 """
761 pairs = Counter(iter(data)).most_common(1)
762 try:
763 return pairs[0][0]
764 except IndexError:
765 raise StatisticsError('no mode for empty data') from None
766
767
768 def multimode(data):
769 """Return a list of the most frequently occurring values.
770
771 Will return more than one result if there are multiple modes
772 or an empty list if *data* is empty.
773
774 >>> multimode('aabbbbbbbbcc')
775 ['b']
776 >>> multimode('aabbbbccddddeeffffgg')
777 ['b', 'd', 'f']
778 >>> multimode('')
779 []
780 """
781 counts = Counter(iter(data))
782 if not counts:
783 return []
784 maxcount = max(counts.values())
785 return [value for value, count in counts.items() if count == maxcount]
786
787
788 # Notes on methods for computing quantiles
789 # ----------------------------------------
790 #
791 # There is no one perfect way to compute quantiles. Here we offer
792 # two methods that serve common needs. Most other packages
793 # surveyed offered at least one or both of these two, making them
794 # "standard" in the sense of "widely-adopted and reproducible".
795 # They are also easy to explain, easy to compute manually, and have
796 # straight-forward interpretations that aren't surprising.
797
798 # The default method is known as "R6", "PERCENTILE.EXC", or "expected
799 # value of rank order statistics". The alternative method is known as
800 # "R7", "PERCENTILE.INC", or "mode of rank order statistics".
801
802 # For sample data where there is a positive probability for values
803 # beyond the range of the data, the R6 exclusive method is a
804 # reasonable choice. Consider a random sample of nine values from a
805 # population with a uniform distribution from 0.0 to 1.0. The
806 # distribution of the third ranked sample point is described by
807 # betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
808 # mean=0.300. Only the latter (which corresponds with R6) gives the
809 # desired cut point with 30% of the population falling below that
810 # value, making it comparable to a result from an inv_cdf() function.
811 # The R6 exclusive method is also idempotent.
812
813 # For describing population data where the end points are known to
814 # be included in the data, the R7 inclusive method is a reasonable
815 # choice. Instead of the mean, it uses the mode of the beta
816 # distribution for the interior points. Per Hyndman & Fan, "One nice
817 # property is that the vertices of Q7(p) divide the range into n - 1
818 # intervals, and exactly 100p% of the intervals lie to the left of
819 # Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."
820
821 # If needed, other methods could be added. However, for now, the
822 # position is that fewer options make for easier choices and that
823 # external packages can be used for anything more advanced.
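
# A quick worked comparison (illustrative, not from the original text):
# for data = [1, 2, 3, 4], the exclusive method gives quartile cut points
# [1.25, 2.5, 3.75] while the inclusive method gives [1.75, 2.5, 3.25];
# the exclusive cuts extend beyond the observed range to allow for
# unseen population values.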
824
825 def quantiles(data, *, n=4, method='exclusive'):
826 """Divide *data* into *n* continuous intervals with equal probability.
827
828 Returns a list of (n - 1) cut points separating the intervals.
829
830 Set *n* to 4 for quartiles (the default). Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles which gives the 99 cut points that
    separate *data* into 100 equal-sized groups.
833
    The *data* can be any iterable containing sample data.
835 The cut points are linearly interpolated between data points.
836
837 If *method* is set to *inclusive*, *data* is treated as population
838 data. The minimum value is treated as the 0th percentile and the
839 maximum value is treated as the 100th percentile.
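
    An illustrative example (cut points computed by hand for this sketch):

    >>> quantiles([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    [2.75, 5.5, 8.25]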
840 """
841 if n < 1:
842 raise StatisticsError('n must be at least 1')
843 data = sorted(data)
844 ld = len(data)
845 if ld < 2:
846 raise StatisticsError('must have at least two data points')
847 if method == 'inclusive':
848 m = ld - 1
849 result = []
850 for i in range(1, n):
851 j, delta = divmod(i * m, n)
852 interpolated = (data[j] * (n - delta) + data[j + 1] * delta) / n
853 result.append(interpolated)
854 return result
855 if method == 'exclusive':
856 m = ld + 1
857 result = []
858 for i in range(1, n):
859 j = i * m // n # rescale i to m/n
860 j = 1 if j < 1 else ld-1 if j > ld-1 else j # clamp to 1 .. ld-1
861 delta = i*m - j*n # exact integer math
862 interpolated = (data[j - 1] * (n - delta) + data[j] * delta) / n
863 result.append(interpolated)
864 return result
865 raise ValueError(f'Unknown method: {method!r}')
866
867
868 # === Measures of spread ===
869
870 # See http://mathworld.wolfram.com/Variance.html
871 # http://mathworld.wolfram.com/SampleVariance.html
872
873
874 def variance(data, xbar=None):
875 """Return the sample variance of data.
876
877 data should be an iterable of Real-valued numbers, with at least two
878 values. The optional argument xbar, if given, should be the mean of
879 the data. If it is missing or None, the mean is automatically calculated.
880
881 Use this function when your data is a sample from a population. To
882 calculate the variance from the entire population, see ``pvariance``.
883
884 Examples:
885
886 >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
887 >>> variance(data)
888 1.3720238095238095
889
890 If you have already calculated the mean of your data, you can pass it as
891 the optional second argument ``xbar`` to avoid recalculating it:
892
893 >>> m = mean(data)
894 >>> variance(data, m)
895 1.3720238095238095
896
897 This function does not check that ``xbar`` is actually the mean of
898 ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
899 impossible results.
900
901 Decimals and Fractions are supported:
902
903 >>> from decimal import Decimal as D
904 >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
905 Decimal('31.01875')
906
907 >>> from fractions import Fraction as F
908 >>> variance([F(1, 6), F(1, 2), F(5, 3)])
909 Fraction(67, 108)
910
911 """
912 T, ss, c, n = _ss(data, xbar)
913 if n < 2:
914 raise StatisticsError('variance requires at least two data points')
915 return _convert(ss / (n - 1), T)
916
917
918 def pvariance(data, mu=None):
919 """Return the population variance of ``data``.
920
921 data should be a sequence or iterable of Real-valued numbers, with at least one
922 value. The optional argument mu, if given, should be the mean of
923 the data. If it is missing or None, the mean is automatically calculated.
924
925 Use this function to calculate the variance from the entire population.
926 To estimate the variance from a sample, the ``variance`` function is
927 usually a better choice.
928
929 Examples:
930
931 >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
932 >>> pvariance(data)
933 1.25
934
935 If you have already calculated the mean of the data, you can pass it as
936 the optional second argument to avoid recalculating it:
937
938 >>> mu = mean(data)
939 >>> pvariance(data, mu)
940 1.25
941
942 Decimals and Fractions are supported:
943
944 >>> from decimal import Decimal as D
945 >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
946 Decimal('24.815')
947
948 >>> from fractions import Fraction as F
949 >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
950 Fraction(13, 72)
951
952 """
953 T, ss, c, n = _ss(data, mu)
954 if n < 1:
955 raise StatisticsError('pvariance requires at least one data point')
956 return _convert(ss / n, T)
957
958
959 def stdev(data, xbar=None):
960 """Return the square root of the sample variance.
961
962 See ``variance`` for arguments and other details.
963
964 >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
965 1.0810874155219827
966
967 """
968 T, ss, c, n = _ss(data, xbar)
969 if n < 2:
970 raise StatisticsError('stdev requires at least two data points')
971 mss = ss / (n - 1)
972 if issubclass(T, Decimal):
973 return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
974 return _float_sqrt_of_frac(mss.numerator, mss.denominator)
975
976
977 def pstdev(data, mu=None):
978 """Return the square root of the population variance.
979
980 See ``pvariance`` for arguments and other details.
981
982 >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
983 0.986893273527251
984
985 """
986 T, ss, c, n = _ss(data, mu)
987 if n < 1:
988 raise StatisticsError('pstdev requires at least one data point')
989 mss = ss / n
990 if issubclass(T, Decimal):
991 return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
992 return _float_sqrt_of_frac(mss.numerator, mss.denominator)
993
994
995 def _mean_stdev(data):
996 """In one pass, compute the mean and sample standard deviation as floats."""
997 T, ss, xbar, n = _ss(data)
998 if n < 2:
999 raise StatisticsError('stdev requires at least two data points')
1000 mss = ss / (n - 1)
1001 try:
1002 return float(xbar), _float_sqrt_of_frac(mss.numerator, mss.denominator)
1003 except AttributeError:
        # Handle NaNs and Infs gracefully
1005 return float(xbar), float(xbar) / float(ss)
1006
1007
1008 # === Statistics for relations between two inputs ===
1009
1010 # See https://en.wikipedia.org/wiki/Covariance
1011 # https://en.wikipedia.org/wiki/Pearson_correlation_coefficient
1012 # https://en.wikipedia.org/wiki/Simple_linear_regression
1013
1014
1015 def covariance(x, y, /):
1016 """Covariance
1017
1018 Return the sample covariance of two inputs *x* and *y*. Covariance
1019 is a measure of the joint variability of two inputs.
1020
1021 >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
1022 >>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
1023 >>> covariance(x, y)
1024 0.75
1025 >>> z = [9, 8, 7, 6, 5, 4, 3, 2, 1]
1026 >>> covariance(x, z)
1027 -7.5
1028 >>> covariance(z, x)
1029 -7.5
1030
1031 """
1032 n = len(x)
1033 if len(y) != n:
1034 raise StatisticsError('covariance requires that both inputs have same number of data points')
1035 if n < 2:
1036 raise StatisticsError('covariance requires at least two data points')
1037 xbar = fsum(x) / n
1038 ybar = fsum(y) / n
1039 sxy = sumprod((xi - xbar for xi in x), (yi - ybar for yi in y))
1040 return sxy / (n - 1)
1041
1042
1043 def correlation(x, y, /, *, method='linear'):
1044 """Pearson's correlation coefficient
1045
1046 Return the Pearson's correlation coefficient for two inputs. Pearson's
1047 correlation coefficient *r* takes values between -1 and +1. It measures
1048 the strength and direction of a linear relationship.
1049
1050 >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
1051 >>> y = [9, 8, 7, 6, 5, 4, 3, 2, 1]
1052 >>> correlation(x, x)
1053 1.0
1054 >>> correlation(x, y)
1055 -1.0
1056
1057 If *method* is "ranked", computes Spearman's rank correlation coefficient
1058 for two inputs. The data is replaced by ranks. Ties are averaged
1059 so that equal values receive the same rank. The resulting coefficient
1060 measures the strength of a monotonic relationship.
1061
1062 Spearman's rank correlation coefficient is appropriate for ordinal
1063 data or for continuous data that doesn't meet the linear proportion
1064 requirement for Pearson's correlation coefficient.
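
    An illustrative ranked example (data invented for this sketch; any
    strictly increasing relationship yields +1.0):

    >>> correlation([1, 2, 3, 4, 5], [2, 4, 8, 16, 32], method='ranked')
    1.0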
1065 """
1066 n = len(x)
1067 if len(y) != n:
1068 raise StatisticsError('correlation requires that both inputs have same number of data points')
1069 if n < 2:
1070 raise StatisticsError('correlation requires at least two data points')
1071 if method not in {'linear', 'ranked'}:
1072 raise ValueError(f'Unknown method: {method!r}')
1073 if method == 'ranked':
1074 start = (n - 1) / -2 # Center rankings around zero
1075 x = _rank(x, start=start)
1076 y = _rank(y, start=start)
1077 else:
1078 xbar = fsum(x) / n
1079 ybar = fsum(y) / n
1080 x = [xi - xbar for xi in x]
1081 y = [yi - ybar for yi in y]
1082 sxy = sumprod(x, y)
1083 sxx = sumprod(x, x)
1084 syy = sumprod(y, y)
1085 try:
1086 return sxy / sqrt(sxx * syy)
1087 except ZeroDivisionError:
1088 raise StatisticsError('at least one of the inputs is constant')
1089
1090
1091 LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept'))
1092
1093
1094 def linear_regression(x, y, /, *, proportional=False):
1095 """Slope and intercept for simple linear regression.
1096
1097 Return the slope and intercept of simple linear regression
1098 parameters estimated using ordinary least squares. Simple linear
    regression describes the relationship between an independent variable
1100 *x* and a dependent variable *y* in terms of a linear function:
1101
1102 y = slope * x + intercept + noise
1103
1104 where *slope* and *intercept* are the regression parameters that are
1105 estimated, and noise represents the variability of the data that was
1106 not explained by the linear regression (it is equal to the
1107 difference between predicted and actual values of the dependent
1108 variable).
1109
1110 The parameters are returned as a named tuple.
1111
1112 >>> x = [1, 2, 3, 4, 5]
1113 >>> noise = NormalDist().samples(5, seed=42)
1114 >>> y = [3 * x[i] + 2 + noise[i] for i in range(5)]
1115 >>> linear_regression(x, y) #doctest: +ELLIPSIS
1116 LinearRegression(slope=3.09078914170..., intercept=1.75684970486...)
1117
1118 If *proportional* is true, the independent variable *x* and the
1119 dependent variable *y* are assumed to be directly proportional.
1120 The data is fit to a line passing through the origin.
1121
1122 Since the *intercept* will always be 0.0, the underlying linear
1123 function simplifies to:
1124
1125 y = slope * x + noise
1126
1127 >>> y = [3 * x[i] + noise[i] for i in range(5)]
1128 >>> linear_regression(x, y, proportional=True) #doctest: +ELLIPSIS
1129 LinearRegression(slope=3.02447542484..., intercept=0.0)
1130
1131 """
1132 n = len(x)
1133 if len(y) != n:
1134 raise StatisticsError('linear regression requires that both inputs have same number of data points')
1135 if n < 2:
1136 raise StatisticsError('linear regression requires at least two data points')
1137 if not proportional:
1138 xbar = fsum(x) / n
1139 ybar = fsum(y) / n
1140 x = [xi - xbar for xi in x] # List because used three times below
1141 y = (yi - ybar for yi in y) # Generator because only used once below
1142 sxy = sumprod(x, y) + 0.0 # Add zero to coerce result to a float
1143 sxx = sumprod(x, x)
1144 try:
1145 slope = sxy / sxx # equivalent to: covariance(x, y) / variance(x)
1146 except ZeroDivisionError:
1147 raise StatisticsError('x is constant')
1148 intercept = 0.0 if proportional else ybar - slope * xbar
1149 return LinearRegression(slope=slope, intercept=intercept)
1150
1151
1152 ## Normal Distribution #####################################################
1153
1154
1155 def _normal_dist_inv_cdf(p, mu, sigma):
1156 # There is no closed-form solution to the inverse CDF for the normal
1157 # distribution, so we use a rational approximation instead:
1158 # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
1159 # Normal Distribution". Applied Statistics. Blackwell Publishing. 37
1160 # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
1161 q = p - 0.5
1162 if fabs(q) <= 0.425:
1163 r = 0.180625 - q * q
1164 # Hash sum: 55.88319_28806_14901_4439
1165 num = (((((((2.50908_09287_30122_6727e+3 * r +
1166 3.34305_75583_58812_8105e+4) * r +
1167 6.72657_70927_00870_0853e+4) * r +
1168 4.59219_53931_54987_1457e+4) * r +
1169 1.37316_93765_50946_1125e+4) * r +
1170 1.97159_09503_06551_4427e+3) * r +
1171 1.33141_66789_17843_7745e+2) * r +
1172 3.38713_28727_96366_6080e+0) * q
1173 den = (((((((5.22649_52788_52854_5610e+3 * r +
1174 2.87290_85735_72194_2674e+4) * r +
1175 3.93078_95800_09271_0610e+4) * r +
1176 2.12137_94301_58659_5867e+4) * r +
1177 5.39419_60214_24751_1077e+3) * r +
1178 6.87187_00749_20579_0830e+2) * r +
1179 4.23133_30701_60091_1252e+1) * r +
1180 1.0)
1181 x = num / den
1182 return mu + (x * sigma)
1183 r = p if q <= 0.0 else 1.0 - p
1184 r = sqrt(-log(r))
1185 if r <= 5.0:
1186 r = r - 1.6
1187 # Hash sum: 49.33206_50330_16102_89036
1188 num = (((((((7.74545_01427_83414_07640e-4 * r +
1189 2.27238_44989_26918_45833e-2) * r +
1190 2.41780_72517_74506_11770e-1) * r +
1191 1.27045_82524_52368_38258e+0) * r +
1192 3.64784_83247_63204_60504e+0) * r +
1193 5.76949_72214_60691_40550e+0) * r +
1194 4.63033_78461_56545_29590e+0) * r +
1195 1.42343_71107_49683_57734e+0)
1196 den = (((((((1.05075_00716_44416_84324e-9 * r +
1197 5.47593_80849_95344_94600e-4) * r +
1198 1.51986_66563_61645_71966e-2) * r +
1199 1.48103_97642_74800_74590e-1) * r +
1200 6.89767_33498_51000_04550e-1) * r +
1201 1.67638_48301_83803_84940e+0) * r +
1202 2.05319_16266_37758_82187e+0) * r +
1203 1.0)
1204 else:
1205 r = r - 5.0
1206 # Hash sum: 47.52583_31754_92896_71629
1207 num = (((((((2.01033_43992_92288_13265e-7 * r +
1208 2.71155_55687_43487_57815e-5) * r +
1209 1.24266_09473_88078_43860e-3) * r +
1210 2.65321_89526_57612_30930e-2) * r +
1211 2.96560_57182_85048_91230e-1) * r +
1212 1.78482_65399_17291_33580e+0) * r +
1213 5.46378_49111_64114_36990e+0) * r +
1214 6.65790_46435_01103_77720e+0)
1215 den = (((((((2.04426_31033_89939_78564e-15 * r +
1216 1.42151_17583_16445_88870e-7) * r +
1217 1.84631_83175_10054_68180e-5) * r +
1218 7.86869_13114_56132_59100e-4) * r +
1219 1.48753_61290_85061_48525e-2) * r +
1220 1.36929_88092_27358_05310e-1) * r +
1221 5.99832_20655_58879_37690e-1) * r +
1222 1.0)
1223 x = num / den
1224 if q < 0.0:
1225 x = -x
1226 return mu + (x * sigma)
1227
1228
1229 # If available, use C implementation
1230 try:
1231 from _statistics import _normal_dist_inv_cdf
1232 except ImportError:
1233 pass
1234
1235
class NormalDist:
1237 "Normal distribution of a random variable"
1238 # https://en.wikipedia.org/wiki/Normal_distribution
1239 # https://en.wikipedia.org/wiki/Variance#Properties
1240
1241 __slots__ = {
1242 '_mu': 'Arithmetic mean of a normal distribution',
1243 '_sigma': 'Standard deviation of a normal distribution',
1244 }
1245
1246 def __init__(self, mu=0.0, sigma=1.0):
1247 "NormalDist where mu is the mean and sigma is the standard deviation."
1248 if sigma < 0.0:
1249 raise StatisticsError('sigma must be non-negative')
1250 self._mu = float(mu)
1251 self._sigma = float(sigma)
1252
1253 @classmethod
1254 def from_samples(cls, data):
1255 "Make a normal distribution instance from sample data."
1256 return cls(*_mean_stdev(data))
1257
1258 def samples(self, n, *, seed=None):
1259 "Generate *n* samples for a given mean and standard deviation."
1260 gauss = random.gauss if seed is None else random.Random(seed).gauss
1261 mu, sigma = self._mu, self._sigma
1262 return [gauss(mu, sigma) for _ in repeat(None, n)]
1263
1264 def pdf(self, x):
1265 "Probability density function. P(x <= X < x+dx) / dx"
1266 variance = self._sigma * self._sigma
1267 if not variance:
1268 raise StatisticsError('pdf() not defined when sigma is zero')
1269 diff = x - self._mu
1270 return exp(diff * diff / (-2.0 * variance)) / sqrt(tau * variance)
1271
1272 def cdf(self, x):
1273 "Cumulative distribution function. P(X <= x)"
1274 if not self._sigma:
1275 raise StatisticsError('cdf() not defined when sigma is zero')
1276 return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * _SQRT2)))
1277
1278 def inv_cdf(self, p):
1279 """Inverse cumulative distribution function. x : P(X <= x) = p
1280
1281 Finds the value of the random variable such that the probability of
1282 the variable being less than or equal to that value equals the given
1283 probability.
1284
1285 This function is also called the percent point function or quantile
1286 function.
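
        By symmetry, the median of the standard normal distribution is at
        p = 0.5 (illustrative):

        >>> NormalDist().inv_cdf(0.5)
        0.0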
1287 """
1288 if p <= 0.0 or p >= 1.0:
1289 raise StatisticsError('p must be in the range 0.0 < p < 1.0')
1290 return _normal_dist_inv_cdf(p, self._mu, self._sigma)
1291
1292 def quantiles(self, n=4):
1293 """Divide into *n* continuous intervals with equal probability.
1294
1295 Returns a list of (n - 1) cut points separating the intervals.
1296
1297 Set *n* to 4 for quartiles (the default). Set *n* to 10 for deciles.
        Set *n* to 100 for percentiles which gives the 99 cut points that
        separate the normal distribution into 100 equal-sized groups.
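
        For the standard normal distribution, the quartile cut points land
        near +/-0.6745 (illustrative):

        >>> [round(q, 4) for q in NormalDist().quantiles()]
        [-0.6745, 0.0, 0.6745]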
1300 """
1301 return [self.inv_cdf(i / n) for i in range(1, n)]
1302
1303 def overlap(self, other):
1304 """Compute the overlapping coefficient (OVL) between two normal distributions.
1305
1306 Measures the agreement between two normal probability distributions.
1307 Returns a value between 0.0 and 1.0 giving the overlapping area in
1308 the two underlying probability density functions.
1309
1310 >>> N1 = NormalDist(2.4, 1.6)
1311 >>> N2 = NormalDist(3.2, 2.0)
1312 >>> N1.overlap(N2)
1313 0.8035050657330205
1314 """
1315 # See: "The overlapping coefficient as a measure of agreement between
1316 # probability distributions and point estimation of the overlap of two
1317 # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
1318 # http://dx.doi.org/10.1080/03610928908830127
1319 if not isinstance(other, NormalDist):
1320 raise TypeError('Expected another NormalDist instance')
1321 X, Y = self, other
1322 if (Y._sigma, Y._mu) < (X._sigma, X._mu): # sort to assure commutativity
1323 X, Y = Y, X
1324 X_var, Y_var = X.variance, Y.variance
1325 if not X_var or not Y_var:
1326 raise StatisticsError('overlap() not defined when sigma is zero')
1327 dv = Y_var - X_var
1328 dm = fabs(Y._mu - X._mu)
1329 if not dv:
1330 return 1.0 - erf(dm / (2.0 * X._sigma * _SQRT2))
1331 a = X._mu * Y_var - Y._mu * X_var
1332 b = X._sigma * Y._sigma * sqrt(dm * dm + dv * log(Y_var / X_var))
1333 x1 = (a + b) / dv
1334 x2 = (a - b) / dv
1335 return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))
1336
1337 def zscore(self, x):
1338 """Compute the Standard Score. (x - mean) / stdev
1339
1340 Describes *x* in terms of the number of standard deviations
1341 above or below the mean of the normal distribution.
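
        An illustrative example on an IQ-style scale (mean 100, stdev 15):

        >>> NormalDist(100, 15).zscore(130)
        2.0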
1342 """
1343 # https://www.statisticshowto.com/probability-and-statistics/z-score/
1344 if not self._sigma:
1345 raise StatisticsError('zscore() not defined when sigma is zero')
1346 return (x - self._mu) / self._sigma
1347
1348 @property
1349 def mean(self):
1350 "Arithmetic mean of the normal distribution."
1351 return self._mu
1352
1353 @property
1354 def median(self):
1355 "Return the median of the normal distribution"
1356 return self._mu
1357
1358 @property
1359 def mode(self):
1360 """Return the mode of the normal distribution
1361
        The mode is the value x at which the probability density
        function (pdf) takes its maximum value.
1364 """
1365 return self._mu
1366
1367 @property
1368 def stdev(self):
1369 "Standard deviation of the normal distribution."
1370 return self._sigma
1371
1372 @property
1373 def variance(self):
1374 "Square of the standard deviation."
1375 return self._sigma * self._sigma
1376
1377 def __add__(x1, x2):
1378 """Add a constant or another NormalDist instance.
1379
1380 If *other* is a constant, translate mu by the constant,
1381 leaving sigma unchanged.
1382
1383 If *other* is a NormalDist, add both the means and the variances.
1384 Mathematically, this works only if the two distributions are
1385 independent or if they are jointly normally distributed.
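
        A 3-4-5 illustration (values chosen for this sketch):

        >>> NormalDist(1.0, 3.0) + NormalDist(2.0, 4.0)
        NormalDist(mu=3.0, sigma=5.0)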
1386 """
1387 if isinstance(x2, NormalDist):
1388 return NormalDist(x1._mu + x2._mu, hypot(x1._sigma, x2._sigma))
1389 return NormalDist(x1._mu + x2, x1._sigma)
1390
1391 def __sub__(x1, x2):
1392 """Subtract a constant or another NormalDist instance.
1393
1394 If *other* is a constant, translate by the constant mu,
1395 leaving sigma unchanged.
1396
1397 If *other* is a NormalDist, subtract the means and add the variances.
1398 Mathematically, this works only if the two distributions are
1399 independent or if they are jointly normally distributed.
1400 """
1401 if isinstance(x2, NormalDist):
1402 return NormalDist(x1._mu - x2._mu, hypot(x1._sigma, x2._sigma))
1403 return NormalDist(x1._mu - x2, x1._sigma)
1404
1405 def __mul__(x1, x2):
1406 """Multiply both mu and sigma by a constant.
1407
1408 Used for rescaling, perhaps to change measurement units.
1409 Sigma is scaled with the absolute value of the constant.
1410 """
1411 return NormalDist(x1._mu * x2, x1._sigma * fabs(x2))
1412
1413 def __truediv__(x1, x2):
1414 """Divide both mu and sigma by a constant.
1415
1416 Used for rescaling, perhaps to change measurement units.
1417 Sigma is scaled with the absolute value of the constant.
1418 """
1419 return NormalDist(x1._mu / x2, x1._sigma / fabs(x2))
1420
1421 def __pos__(x1):
1422 "Return a copy of the instance."
1423 return NormalDist(x1._mu, x1._sigma)
1424
1425 def __neg__(x1):
1426 "Negates mu while keeping sigma the same."
1427 return NormalDist(-x1._mu, x1._sigma)
1428
1429 __radd__ = __add__
1430
1431 def __rsub__(x1, x2):
1432 "Subtract a NormalDist from a constant or another NormalDist."
1433 return -(x1 - x2)
1434
1435 __rmul__ = __mul__
1436
1437 def __eq__(x1, x2):
1438 "Two NormalDist objects are equal if their mu and sigma are both equal."
1439 if not isinstance(x2, NormalDist):
1440 return NotImplemented
1441 return x1._mu == x2._mu and x1._sigma == x2._sigma
1442
1443 def __hash__(self):
1444 "NormalDist objects hash equal if their mu and sigma are both equal."
1445 return hash((self._mu, self._sigma))
1446
1447 def __repr__(self):
1448 return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'
1449
1450 def __getstate__(self):
1451 return self._mu, self._sigma
1452
1453 def __setstate__(self, state):
1454 self._mu, self._sigma = state