/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013, 2018, 2020, 2022 Free Software
Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Clean up and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
                                       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
                                       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
                                   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE /* When tuning, this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. the smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
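
/* For instance (an illustration, not part of the library): with k = 4,
   pl = 2001 gives ceil (2001/16) = 126, so mpn_fft_next_size returns
   126 << 4 = 2016, the smallest multiple of 2^4 that is >= 2001.  */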


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
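
/* Example (illustrative): for k = 2 this builds
     l[0] = {0},  l[1] = {0, 1},  l[2] = {0, 2, 1, 3},
   i.e. l[2][j] is j with its 2 bits reversed (1 = binary 01 -> 10 = 2).  */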


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n) /* negate */
    {
      /* r[0..m-1]  <--  lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- -lshift(a[0]..a[n-m-1], sh) */

      m -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - m, m + 1, sh);
          rd = r[m];
          cc = mpn_lshiftc (r + m, a, n - m, sh);
        }
      else
        {
          MPN_COPY (r, a + n - m, m);
          rd = a[n];
          mpn_com (r + m, a, n - m);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      ++cc;
      MPN_INCR_U (r, n + 1, cc);

      ++rd;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = rd + (rd == 0);
      r = r + m + (rd == 0);
      MPN_INCR_U (r, n + 1 - m - (rd == 0), cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <--  lshift(a[0]..a[n-m-1], sh) */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - m, m + 1, sh);
          rd = ~r[m];
          /* {r, m+1} = {a+n-m, m+1} << sh */
          cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
        }
      else
        {
          /* r[m] is not used below, but we save a test for m=0 */
          mpn_com (r, a + n - m, m + 1);
          rd = a[n];
          MPN_COPY (r + m, a, n - m);
          cc = 0;
        }

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[m] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, m, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[m..n] */

      r[n] = 2; /* Add a value, to avoid borrow propagation */
      MPN_DECR_U (r + m, n - m + 1, cc);
      MPN_DECR_U (r + m, n - m + 1, rd);
      /* Remove the added value, and check for a possible borrow. */
      if (UNLIKELY ((r[n] -= 2) != 0))
        {
          mp_limb_t cy = -r[n];
          /* cy should always be 1, except in the very unlikely case
             m=n-1, r[m]=0, cc+rd > GMP_NUMB_MAX+1.  Never triggered.
             Is it actually possible? */
          r[n] = 0;
          MPN_INCR_U (r, n + 1, cy);
        }
    }
}
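
/* Why a shift suffices (an illustrative note): modulo
   F = 2^(n*GMP_NUMB_BITS)+1 we have 2^(n*GMP_NUMB_BITS) = -1, so
   multiplying by 2^d is a rotation in which the part shifted out at the
   top wraps around with its sign flipped.  E.g. with n = 2 limbs and
   d = 3*GMP_NUMB_BITS/2, a = 2^(GMP_NUMB_BITS/2) maps to
   2^(2*GMP_NUMB_BITS) = -1 = F - 1.  */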

#if HAVE_NATIVE_mpn_add_n_sub_n
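/* A0 <- A0 + tp and Ai <- A0 - tp (both using the old A0), reduced
   mod 2^(n*GMP_NUMB_BITS)+1.  Assumes the inputs are semi-normalized;
   mpn_add_n_sub_n returns 2*carry + borrow, unpacked below as
   (cyas >> 1) and (cyas & 1).  */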
static inline void
mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
{
  mp_limb_t cyas, c, x;

  cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);

  c = A0[n] - tp[n] - (cyas & 1);
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  Ai[n] = x + c;
  MPN_INCR_U (Ai, n + 1, x);

  c = A0[n] + tp[n] + (cyas >> 1);
  x = (c - 1) & -(c != 0);
  A0[n] = c - x;
  MPN_DECR_U (A0, n + 1, x);
}

#else /* ! HAVE_NATIVE_mpn_add_n_sub_n */

/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
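
/* A quick check of the branch-free reduction above (an illustration): the
   exact sum is {r, n} + c * 2^(n*GMP_NUMB_BITS) with 0 <= c <= 3, and the
   code subtracts x = max (c - 1, 0) copies of the modulus, leaving the
   semi-normalized top limb r[n] = c - x <= 1.  E.g. c = 3 gives x = 2,
   r[n] = 1, and MPN_DECR_U removes 2 from {r, n+1}.  */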

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
#endif /* HAVE_NATIVE_mpn_add_n_sub_n */

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        { /* Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1); */
          mp_limb_t cc = Ap[0][n] - 1;
          Ap[0][n] = 1;
          MPN_DECR_U (Ap[0], n + 1, cc);
        }
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        { /* Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1); */
          mp_limb_t cc = ~Ap[inc][n] + 1;
          Ap[inc][n] = 0;
          MPN_INCR_U (Ap[inc], n + 1, cc);
        }
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap, K2, ll - 1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap + inc, K2, ll - 1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
          mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
#else
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
        }
    }
}
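
/* Structure note (illustrative): the two recursive calls above transform
   the even- and odd-indexed residues (stride 2*inc), and the loop then
   combines them with butterflies twiddled by 2^(lk[0]*omega); per the
   header comment the results land in the bit-reversed order recorded in
   l[k], e.g. for K = 4 at indices {0, 2, 1, 3}.  */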
/* Given ap[0..n] with ap[n] <= 1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger
             it in practice, and thereby to confirm that this code is
             correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}
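
/* For instance (illustrative): ap[n] = 1, ap[0] = 5, rest 0, i.e.
   ap = 2^(n*GMP_NUMB_BITS) + 5, reduces to ap[0] = 4 with ap[n] = 0;
   while ap = 2^(n*GMP_NUMB_BITS) exactly is already a valid residue and
   is restored unchanged by the inner branch.  */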

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  unsigned k;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS ((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          mp_size_t K3;
          for (;;)
            {
              K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS (nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
        {
          fft_l[i] = tmp;
          tmp += (mp_size_t) 1 << i;
        }

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0 * (double) n / nprime2 / K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
#if ! TUNE_PROGRAM_BUILD
  else if (MPN_MULMOD_BKNP1_USABLE (n, k, MUL_FFT_MODF_THRESHOLD))
    {
      mp_ptr a;
      mp_size_t n_k = n / k;

      if (sqr)
        {
          mp_ptr tp = TMP_SALLOC_LIMBS (mpn_sqrmod_bknp1_itch (n));
          for (i = 0; i < K; i++)
            {
              a = *ap++;
              mpn_sqrmod_bknp1 (a, a, n_k, k, tp);
            }
        }
      else
        {
          mp_ptr b, tp = TMP_SALLOC_LIMBS (mpn_mulmod_bknp1_itch (n));
          for (i = 0; i < K; i++)
            {
              a = *ap++;
              b = *bp++;
              mpn_mulmod_bknp1 (a, a, b, n_k, k, tp);
            }
        }
    }
#endif
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              cc = mpn_add_1 (tp, tp, n2, cc);
              /* If mpn_add_1 gives a carry (cc != 0),
                 the result (tp) is at most GMP_NUMB_MAX - 1,
                 so the following addition can't overflow.  */
              tp[0] += cc;
            }
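          /* {tp, 2n} = lo + hi*2^(n*GMP_NUMB_BITS) now holds the product
             with the a[n], b[n] cross terms folded in; since
             2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1, the
             residue is lo - hi, computed below (a borrow is fixed by
             adding 1, because -2^(n*GMP_NUMB_BITS) = 1 mod the same
             modulus).  */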
          cc = mpn_sub_n (a, tp, tpn, n);
          a[n] = 0;
          MPN_INCR_U (a, n + 1, cc);
        }
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        { /* Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1); */
          mp_limb_t cc = Ap[0][n] - 1;
          Ap[0][n] = 1;
          MPN_DECR_U (Ap[0], n + 1, cc);
        }
      if (cy) /* Ap[1][n] can be -1 or -2 */
        { /* Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1); */
          mp_limb_t cc = ~Ap[1][n] + 1;
          Ap[1][n] = 0;
          MPN_INCR_U (Ap[1], n + 1, cc);
        }
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
          mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
#else
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
        }
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}
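
/* Why the identity holds (illustrative, with L = GMP_NUMB_BITS): modulo
   2^(n*L)+1 we have 2^(n*L) = -1, hence 2^(2*n*L) = 1, and therefore
   1/2^k = 2^(2*n*L-k).  Division thus reduces to the multiplication
   performed by mpn_fft_mul_2exp_modF above.  */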


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m}, propagating the carry cc */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  rpn -= mpn_sub (rp, rp, n, ap + n, l);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
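
/* The splitting above in one line (illustrative, L = GMP_NUMB_BITS):
   write {ap, an} = A0 + A1*2^(n*L) + A2*2^(2*n*L) with A0, A1 of n limbs
   and A2 of an-2n limbs; since 2^(n*L) = -1 mod 2^(n*L)+1, the residue is
   A0 - A1 + A2, which is exactly what the additions and the final
   subtraction compute.  */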

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
                       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
                       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;

      tmp = TMP_BALLOC_LIMBS (Kl + 1);
      tmp[Kl] = 0;

#if ! WANT_OLD_FFT_FULL
      ASSERT_ALWAYS (dif <= Kl);
#else
      /* The comment "We must have nl <= 2*K*l." says that
         ((dif = nl - Kl) > Kl) should never happen. */
      if (UNLIKELY (dif > Kl))
        {
          mp_limb_signed_t cy;
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            MPN_INCR_U (tmp, Kl + 1, cy);
          else
            {
              tmp[Kl] = 1;
              MPN_DECR_U (tmp, Kl + 1, -cy - 1);
            }
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
#endif
        {
          mp_limb_t cy;
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          MPN_INCR_U (tmp, Kl + 1, cy);
        }
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
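
/* Note on the weights 2^(i*Mp) applied above (a hedged remark): with
   theta = 2^Mp we have theta^K = 2^(K*Mp) = 2^Nprime = -1 mod 2^Nprime+1,
   so weighting part i by theta^i turns the cyclic convolution computed by
   the transforms into the negacyclic convolution needed for arithmetic
   mod 2^N+1; the matching inverse scalings (the 1/K factor 2^-k and
   powers of theta) are undone via mpn_fft_div_2exp_modF in
   mpn_mul_fft_internal.  */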

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr unusedA, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      cc += mpn_add (n, n, pla - sh, Bp[j], nprime + 1);
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          MPN_DECR_U (p + pla - pl, pl, cc);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
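
/* Check (illustrative): the loop strips min (v, k) factors of 2 from a,
   where 2^v is the largest power of 2 dividing a, so the result is
   a * 2^k / gcd (a, 2^k).  E.g. a = 64, k = 4 returns 4 << 4 = 64
   = lcm (64, 16), and a = 32, k = 6 returns 1 << 6 = 64 = lcm (32, 64).  */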


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
        {
          K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  Bp = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
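
#if 0
/* Illustrative usage sketch, never compiled in: computing
   {rp, pl} = {np, nl} * {mp, ml} mod 2^(pl*GMP_NUMB_BITS)+1.  The helper
   name example_mulmod_2expp1 is hypothetical; the call pattern follows
   mpn_mul_fft's contract above (pl must satisfy
   pl = mpn_fft_next_size (pl, k), and rp must have room for the
   rounded-up pl limbs).  */
static mp_limb_t
example_mulmod_2expp1 (mp_ptr rp, mp_size_t pl,
                       mp_srcptr np, mp_size_t nl,
                       mp_srcptr mp, mp_size_t ml)
{
  int k = mpn_fft_best_k (pl, 0);   /* sqr=0 for a multiply */
  pl = mpn_fft_next_size (pl, k);   /* round pl up to a multiple of 2^k */
  /* the product is {rp, pl} + h*2^(pl*GMP_NUMB_BITS), h the returned limb */
  return mpn_mul_fft (rp, pl, np, nl, mp, ml, k);
}
#endif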

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform an FFT mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j >= 2.  Thus this scheme requires
     pl >= 6 * 2^FFT_FIRST_K.  */

  /* ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
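
  /* Worked instance of the scheme above (illustrative): with k2 = 4 and
     j = 3 we get (pl2, pl3) = (96, 144), usable for 3*j = 9 <= pl/16 <=
     15 = 5*j, i.e. 144 <= pl <= 240; the next interval j = 4 starts at
     pl = 192, so consecutive intervals indeed overlap.  */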

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k2 >= FFT_FIRST_K = 4, pl2 is a multiple
                            of 2^4, thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS (pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3); /* mu */
  ASSERT (cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2); /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT (0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT (-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT (0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  MPN_INCR_U (op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif