1 """Python implementations of some algorithms for use by longobject.c.
2 The goal is to provide asymptotically faster algorithms that can be
3 used for operations on integers with many digits. In those cases, the
4 performance overhead of the Python implementation is not significant
5 since the asymptotic behavior is what dominates runtime. Functions
6 provided by this module should be considered private and not part of any
7 public API.
8
9 Note: for ease of maintainability, please prefer clear code and avoid
10 "micro-optimizations". This module will only be imported and used for
11 integers with a huge number of digits. Saving a few microseconds with
tricky or non-obvious code is not worth it. People looking for
maximum performance should use something like gmpy2."""
14
15 import re
16 import decimal
17
18
def int_to_decimal(n):
    """Asymptotically fast conversion of an 'int' to Decimal.

    Returns a decimal.Decimal that is exactly equal to the int n.
    """

    # Function due to Tim Peters. See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The base conversion algorithms in longobject.c between power-of-2
    # and non-power-of-2 bases are quadratic time. This function uses a
    # divide-and-conquer scheme instead: split the int into a high and a
    # low half by bit count, convert each half recursively, and recombine
    # them as low + high * 2**k using Decimal arithmetic. A string
    # representation can then be obtained by applying str to the result.

    Decimal = decimal.Decimal
    two = Decimal(2)

    # Pieces of at most this many bits are converted directly.
    BIT_LIMIT = 128

    # Maps a bit count w to Decimal(2)**w.
    pow2_cache = {}

    def pow2(w):
        """Return Decimal(2)**w, remembering it (and possibly some
        intermediate powers) for reuse at other recursion levels."""
        cached = pow2_cache.get(w)
        if cached is not None:
            return cached
        if w <= BIT_LIMIT:
            value = two ** w
        elif w - 1 in pow2_cache:
            # Doubling the previous power is cheaper than a multiply.
            previous = pow2_cache[w - 1]
            value = previous + previous
        else:
            half = w >> 1
            # If w is odd, w - half is one larger than half. Compute the
            # smaller power first so that it is cached and the larger one
            # can take the cheap `w - 1 in pow2_cache` branch above.
            smaller = pow2(half)
            larger = pow2(w - half)
            value = smaller * larger
        pow2_cache[w] = value
        return value

    def convert(value, width):
        """Convert a non-negative int of at most `width` bits."""
        if width <= BIT_LIMIT:
            return Decimal(value)
        low_width = width >> 1
        high = value >> low_width
        low = value - (high << low_width)
        low_dec = convert(low, low_width)
        high_dec = convert(high, width - low_width)
        return low_dec + high_dec * pow2(low_width)

    with decimal.localcontext() as ctx:
        # Widen the context as far as possible and trap Inexact so that
        # any rounding raises instead of silently losing digits.
        ctx.prec = decimal.MAX_PREC
        ctx.Emax = decimal.MAX_EMAX
        ctx.Emin = decimal.MIN_EMIN
        ctx.traps[decimal.Inexact] = 1

        is_negative = n < 0
        magnitude = -n if is_negative else n
        result = convert(magnitude, magnitude.bit_length())
        if is_negative:
            # Negate while the wide context is still active: unary minus
            # on a Decimal rounds to the current context precision.
            result = -result
    return result
81
82
def int_to_decimal_string(n):
    """Asymptotically fast conversion of an 'int' to a decimal string.

    Converts n with int_to_decimal() and returns str() of that Decimal.
    """
    as_decimal = int_to_decimal(n)
    return str(as_decimal)
86
87
def _str_to_int_inner(s):
    """Asymptotically fast conversion of a 'str' to an 'int'.

    s is passed in slices to int(), so it is expected to hold only
    decimal digits.
    """

    # Function due to Bjorn Martinsson. See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The base conversion algorithms in longobject.c between power-of-2
    # and non-power-of-2 bases are quadratic time. This function is a
    # divide-and-conquer algorithm built on Python's big int
    # multiplication. Since Python uses the Karatsuba algorithm for
    # multiplication, the time complexity of this function is
    # O(len(s)**1.58).

    # Slices of at most this many digits are handed straight to int().
    DIGLIM = 2048

    # Maps an exponent w to 5**w.
    pow5_cache = {}

    def pow5(w):
        """Return 5**w, remembering it (and possibly some intermediate
        powers) for reuse at other recursion levels."""
        cached = pow5_cache.get(w)
        if cached is not None:
            return cached
        if w <= DIGLIM:
            value = 5 ** w
        elif w - 1 in pow5_cache:
            value = pow5_cache[w - 1] * 5
        else:
            half = w >> 1
            # If w is odd, w - half is one larger than half. Compute the
            # smaller power first so that it is cached and the larger one
            # can take the cheap `w - 1 in pow5_cache` branch above.
            smaller = pow5(half)
            larger = pow5(w - half)
            value = smaller * larger
        pow5_cache[w] = value
        return value

    def convert(start, stop):
        """Return the value of the digits s[start:stop]."""
        if stop - start <= DIGLIM:
            return int(s[start:stop])
        split = (start + stop + 1) >> 1
        low_len = stop - split
        low = convert(split, stop)
        high = convert(start, split)
        # high * 10**low_len, computed as (high * 5**low_len) << low_len.
        return low + ((high * pow5(low_len)) << low_len)

    return convert(0, len(s))
133
134
def int_from_string(s):
    """Asymptotically fast version of PyLong_FromString(), conversion
    of a string of decimal digits into an 'int'."""
    # PyLong_FromString() has already removed a leading +/-, rejected
    # invalid use of underscore characters, checked that the string
    # consists only of digits and underscores, and stripped leading
    # whitespace. What can still be present are underscores and trailing
    # whitespace, so drop both before converting.
    without_trailing_space = s.rstrip()
    digits_only = without_trailing_space.replace('_', '')
    return _str_to_int_inner(digits_only)
144
145
def str_to_int(s):
    """Asymptotically fast version of decimal string to 'int' conversion.

    The whole of s must consist of optional whitespace, an optional
    '+' or '-' sign, a run of decimal digits and underscores, and
    optional trailing whitespace. Raises ValueError otherwise.
    """
    # FIXME: this doesn't support the full syntax that int() supports.
    # In particular, the placement of underscores is not validated here.
    #
    # Use fullmatch() rather than match(): match() only anchors at the
    # start, so input with trailing garbage such as '12abc' would be
    # accepted and silently converted to 12.
    m = re.fullmatch(r'\s*([+-]?)([0-9_]+)\s*', s)
    if not m:
        raise ValueError('invalid literal for int() with base 10')
    v = int_from_string(m.group(2))
    if m.group(1) == '-':
        v = -v
    return v
156
157
158 # Fast integer division, based on code from Mark Dickinson, fast_div.py
159 # GH-47701. Additional refinements and optimizations by Bjorn Martinsson. The
160 # algorithm is due to Burnikel and Ziegler, in their paper "Fast Recursive
161 # Division".
162
# When the dividend is at most this many bits longer than n, _div2n1n
# stops recursing and calls the builtin divmod instead.
_DIV_LIMIT = 4000


def _div2n1n(a, b, n):
    """Divide a 2n-bit nonnegative integer a by an n-bit positive integer
    b, using a recursive divide-and-conquer algorithm.

    Inputs:
      n is a positive integer
      b is a positive integer with exactly n bits
      a is a nonnegative integer such that a < 2**n * b

    Output:
      (q, r) such that a = b*q+r and 0 <= r < b.

    """
    # Base case: the quotient is short enough for the builtin.
    if a.bit_length() - n <= _DIV_LIMIT:
        return divmod(a, b)

    # The split below needs an even n. If n is odd, scale a and b by 2;
    # the quotient is unchanged and the remainder is halved again at the
    # end.
    scaled = n & 1
    if scaled:
        a, b, n = a << 1, b << 1, n + 1

    half = n >> 1
    low_mask = (1 << half) - 1

    # Halves of the divisor.
    b_high = b >> half
    b_low = b & low_mask

    # The dividend as a top part plus two half-width chunks.
    a_top = a >> n
    a_middle = (a >> half) & low_mask
    a_bottom = a & low_mask

    # Two 3-by-2 half-width divisions; the remainder of the first feeds
    # the second.
    q_high, rem = _div3n2n(a_top, a_middle, b, b_high, b_low, half)
    q_low, rem = _div3n2n(rem, a_bottom, b, b_high, b_low, half)

    if scaled:
        rem >>= 1
    return (q_high << half) | q_low, rem
194
195
def _div3n2n(a12, a3, b, b1, b2, n):
    """Helper function for _div2n1n; not intended to be called directly.

    Divides (a12 << n | a3) by b, where b1 and b2 are the high and low
    n-bit halves of b, and returns (quotient, remainder).
    """
    if a12 >> n == b1:
        # a12 is not below b1 << n, so it cannot be passed to _div2n1n.
        # Estimate the quotient as 2**n - 1 and take the remainder
        # a12 - (2**n - 1) * b1.
        quotient = (1 << n) - 1
        remainder = a12 - (b1 << n) + b1
    else:
        quotient, remainder = _div2n1n(a12, b1, n)

    # Append the next chunk of the dividend and account for the low half
    # of the divisor, which the estimate above ignored.
    remainder = ((remainder << n) | a3) - quotient * b2

    # The estimate can only be too large; step it down until the
    # remainder is non-negative.
    while remainder < 0:
        quotient -= 1
        remainder += b
    return quotient, remainder
207
208
def _int2digits(a, n):
    """Decompose non-negative int a into base 2**n

    Input:
      a is a non-negative integer

    Output:
      List of the digits of a in base 2**n in little-endian order,
      meaning the most significant digit is last. The most
      significant digit is guaranteed to be non-zero.
      If a is 0 then the output is an empty list.

    """
    digit_count = (a.bit_length() + n - 1) // n
    digits = [0] * digit_count

    def fill(value, lo, hi):
        """Store the digits of `value` into digits[lo:hi]."""
        if hi == lo + 1:
            digits[lo] = value
            return
        mid = (lo + hi) >> 1
        shift = (mid - lo) * n
        high_part = value >> shift
        # XOR clears the bits that went to high_part and leaves the low
        # `shift` bits.
        low_part = value ^ (high_part << shift)
        fill(low_part, lo, mid)
        fill(high_part, mid, hi)

    if a:
        fill(a, 0, digit_count)
    return digits
238
239
def _digits2int(digits, n):
    """Combine base-2**n digits into an int. This function is the
    inverse of `_int2digits`. For more details, see _int2digits.

    An empty digit list gives 0.
    """
    if not digits:
        return 0

    def combine(lo, hi):
        """Return the value of digits[lo:hi], least significant first."""
        if hi == lo + 1:
            return digits[lo]
        mid = (lo + hi) >> 1
        high_part = combine(mid, hi)
        low_part = combine(lo, mid)
        # The low half spans (mid - lo) digits of n bits each.
        return (high_part << ((mid - lo) * n)) + low_part

    return combine(0, len(digits))
253
254
def _divmod_pos(a, b):
    """Divide a non-negative integer a by a positive integer b, giving
    quotient and remainder."""
    # Grade-school long division in base 2**width, where width is the
    # bit length of b.
    width = b.bit_length()

    remainder = 0
    quotient_digits = []
    # Walk the digits of a from most significant to least significant.
    # The running remainder stays below b and each digit is below
    # 2**width, so every step satisfies the precondition
    # a < 2**n * b of _div2n1n.
    for digit in reversed(_int2digits(a, width)):
        partial = (remainder << width) + digit
        q_digit, remainder = _div2n1n(partial, b, width)
        quotient_digits.append(q_digit)

    # Digits were produced most significant first; _digits2int expects
    # little-endian order.
    quotient_digits.reverse()
    return _digits2int(quotient_digits, width), remainder
270
271
def int_divmod(a, b):
    """Asymptotically fast replacement for divmod, for 'int'.
    Its time complexity is O(n**1.58), where n = #bits(a) + #bits(b).

    Raises ZeroDivisionError if b is 0.
    """
    if b == 0:
        raise ZeroDivisionError

    if b < 0:
        # Flip both signs: the quotient is the same and the remainder is
        # negated.
        quotient, remainder = int_divmod(-a, -b)
        return quotient, -remainder

    if a < 0:
        # With ~a == -a - 1 and ~a == q*b + r, it follows that
        # a == (~q)*b + (b + ~r), and 0 <= b + ~r < b.
        quotient, remainder = int_divmod(~a, b)
        return ~quotient, b + ~remainder

    return _divmod_pos(a, b)