diff --git a/Lib/statistics.py b/Lib/statistics.py index cf8eaa0a61e624..9f1efa21b15e3c 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -137,7 +137,7 @@ from itertools import groupby, repeat from bisect import bisect_left, bisect_right from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum -from operator import itemgetter, mul +from operator import mul from collections import Counter, namedtuple _SQRT2 = sqrt(2.0) @@ -248,6 +248,28 @@ def _exact_ratio(x): x is expected to be an int, Fraction, Decimal or float. """ + + # XXX We should revisit whether using fractions to accumulate exact + # ratios is the right way to go. + + # The integer ratios for binary floats can have numerators or + # denominators with over 300 decimal digits. The problem is more + # acute with decimal floats where the default decimal context + # supports a huge range of exponents from Emin=-999999 to + # Emax=999999. When expanded with as_integer_ratio(), numbers like + # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large + # numerators or denominators that will slow computation. + + # When the integer ratios are accumulated as fractions, the size + # grows to cover the full range from the smallest magnitude to the + # largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300), + # has a 616 digit numerator. Likewise, + # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000')) + # has a 10,003 digit numerator. + + # This doesn't seem to have been a problem in practice, but it is a + # potential pitfall. 
+ try: return x.as_integer_ratio() except AttributeError: @@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'): raise StatisticsError(errmsg) yield x -def _isqrt_frac_rto(n: int, m: int) -> float: + +def _integer_sqrt_of_frac_rto(n: int, m: int) -> int: """Square root of n/m, rounded to the nearest integer using round-to-odd.""" # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf a = math.isqrt(n // m) return a | (a*a*m != n) -# For 53 bit precision floats, the _sqrt_frac() shift is 109. -_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 -def _sqrt_frac(n: int, m: int) -> float: +# For 53 bit precision floats, the bit width used in +# _float_sqrt_of_frac() is 109. +_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3 + + +def _float_sqrt_of_frac(n: int, m: int) -> float: """Square root of n/m as a float, correctly rounded.""" # See principle and proof sketch at: https://bugs.python.org/msg407078 - q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 + q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2 if q >= 0: - numerator = _isqrt_frac_rto(n, m << 2 * q) << q + numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q denominator = 1 else: - numerator = _isqrt_frac_rto(n << -2 * q, m) + numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m) denominator = 1 << -q return numerator / denominator # Convert to float +def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal: + """Square root of n/m as a Decimal, correctly rounded.""" + # Premise: For decimal, computing (n/m).sqrt() can be off + # by 1 ulp from the correctly rounded result. + # Method: Check the result, moving up or down a step if needed. 
+ if n <= 0: + if not n: + return Decimal('0.0') + n, m = -n, -m + + root = (Decimal(n) / Decimal(m)).sqrt() + nr, dr = root.as_integer_ratio() + + plus = root.next_plus() + np, dp = plus.as_integer_ratio() + # test: n / m > ((root + plus) / 2) ** 2 + if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2: + return plus + + minus = root.next_minus() + nm, dm = minus.as_integer_ratio() + # test: n / m < ((root + minus) / 2) ** 2 + if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2: + return minus + + return root + + # === Measures of central tendency (averages) === def mean(data): @@ -869,7 +923,7 @@ def stdev(data, xbar=None): if hasattr(T, 'sqrt'): var = _convert(mss, T) return var.sqrt() - return _sqrt_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) def pstdev(data, mu=None): @@ -888,10 +942,9 @@ def pstdev(data, mu=None): raise StatisticsError('pstdev requires at least one data point') T, ss = _ss(data, mu) mss = ss / n - if hasattr(T, 'sqrt'): - var = _convert(mss, T) - return var.sqrt() - return _sqrt_frac(mss.numerator, mss.denominator) + if issubclass(T, Decimal): + return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) # === Statistics for relations between two inputs === diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index 771a03e707ee01..bacb76a9b036bf 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -2164,9 +2164,9 @@ def test_center_not_at_mean(self): class TestSqrtHelpers(unittest.TestCase): - def test_isqrt_frac_rto(self): + def test_integer_sqrt_of_frac_rto(self): for n, m in itertools.product(range(100), range(1, 1000)): - r = statistics._isqrt_frac_rto(n, m) + r = statistics._integer_sqrt_of_frac_rto(n, m) self.assertIsInstance(r, int) if r*r*m == n: # Root is exact @@ -2177,7 +2177,7 @@ def test_isqrt_frac_rto(self): self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) @requires_IEEE_754 - 
def test_sqrt_frac(self): + def test_float_sqrt_of_frac(self): def is_root_correctly_rounded(x: Fraction, root: float) -> bool: if not x: @@ -2204,22 +2204,59 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool: denonimator: int = randrange(10 ** randrange(50)) + 1 with self.subTest(numerator=numerator, denonimator=denonimator): x: Fraction = Fraction(numerator, denonimator) - root: float = statistics._sqrt_frac(numerator, denonimator) + root: float = statistics._float_sqrt_of_frac(numerator, denonimator) self.assertTrue(is_root_correctly_rounded(x, root)) # Verify that corner cases and error handling match math.sqrt() - self.assertEqual(statistics._sqrt_frac(0, 1), 0.0) + self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0) with self.assertRaises(ValueError): - statistics._sqrt_frac(-1, 1) + statistics._float_sqrt_of_frac(-1, 1) with self.assertRaises(ValueError): - statistics._sqrt_frac(1, -1) + statistics._float_sqrt_of_frac(1, -1) # Error handling for zero denominator matches that for Fraction(1, 0) with self.assertRaises(ZeroDivisionError): - statistics._sqrt_frac(1, 0) + statistics._float_sqrt_of_frac(1, 0) # The result is well defined if both inputs are negative - self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0)) + self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1)) + + def test_decimal_sqrt_of_frac(self): + root: Decimal + numerator: int + denominator: int + + for root, numerator, denominator in [ + (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj + (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up + (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down + ]: + with decimal.localcontext(decimal.DefaultContext): + self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, 
denominator), root) + + # Confirm expected root with a quad precision decimal computation + with decimal.localcontext(decimal.DefaultContext) as ctx: + ctx.prec *= 4 + high_prec_ratio = Decimal(numerator) / Decimal(denominator) + ctx.rounding = decimal.ROUND_05UP + high_prec_root = high_prec_ratio.sqrt() + with decimal.localcontext(decimal.DefaultContext): + target_root = +high_prec_root + self.assertEqual(root, target_root) + + # Verify that corner cases and error handling match Decimal.sqrt() + self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(-1, 1) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(1, -1) + + # Error handling for zero denominator matches that for Fraction(1, 0) + with self.assertRaises(ZeroDivisionError): + statistics._decimal_sqrt_of_frac(1, 0) + + # The result is well defined if both inputs are negative + self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1)) class TestStdev(VarianceStdevMixin, NumericTestCase):
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: