(root)/
Python-3.12.0/
Modules/
_decimal/
libmpdec/
fourstep.c
       1  /*
       2   * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
       3   *
       4   * Redistribution and use in source and binary forms, with or without
       5   * modification, are permitted provided that the following conditions
       6   * are met:
       7   *
       8   * 1. Redistributions of source code must retain the above copyright
       9   *    notice, this list of conditions and the following disclaimer.
      10   *
      11   * 2. Redistributions in binary form must reproduce the above copyright
      12   *    notice, this list of conditions and the following disclaimer in the
      13   *    documentation and/or other materials provided with the distribution.
      14   *
      15   * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
      16   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
      17   * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
      18   * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
      19   * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
      20   * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
      21   * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
      22   * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
      23   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
      24   * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
      25   * SUCH DAMAGE.
      26   */
      27  
      28  
      29  #include "mpdecimal.h"
      30  
      31  #include <assert.h>
      32  
      33  #include "constants.h"
      34  #include "fourstep.h"
      35  #include "numbertheory.h"
      36  #include "sixstep.h"
      37  #include "umodarith.h"
      38  
      39  
      40  /* Bignum: Cache efficient Matrix Fourier Transform for arrays of the
      41     form 3 * 2**n (See literature/matrix-transform.txt). */
      42  
      43  
      44  #ifndef PPRO
      45  static inline void
      46  std_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3,
      47                mpd_uint_t w3table[3], mpd_uint_t umod)
      48  {
      49      mpd_uint_t r1, r2;
      50      mpd_uint_t w;
      51      mpd_uint_t s, tmp;
      52  
      53  
      54      /* k = 0 -> w = 1 */
      55      s = *x1;
      56      s = addmod(s, *x2, umod);
      57      s = addmod(s, *x3, umod);
      58  
      59      r1 = s;
      60  
      61      /* k = 1 */
      62      s = *x1;
      63  
      64      w = w3table[1];
      65      tmp = MULMOD(*x2, w);
      66      s = addmod(s, tmp, umod);
      67  
      68      w = w3table[2];
      69      tmp = MULMOD(*x3, w);
      70      s = addmod(s, tmp, umod);
      71  
      72      r2 = s;
      73  
      74      /* k = 2 */
      75      s = *x1;
      76  
      77      w = w3table[2];
      78      tmp = MULMOD(*x2, w);
      79      s = addmod(s, tmp, umod);
      80  
      81      w = w3table[1];
      82      tmp = MULMOD(*x3, w);
      83      s = addmod(s, tmp, umod);
      84  
      85      *x3 = s;
      86      *x2 = r2;
      87      *x1 = r1;
      88  }
      89  #else /* PPRO */
      90  static inline void
      91  ppro_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3, mpd_uint_t w3table[3],
      92                 mpd_uint_t umod, double *dmod, uint32_t dinvmod[3])
      93  {
      94      mpd_uint_t r1, r2;
      95      mpd_uint_t w;
      96      mpd_uint_t s, tmp;
      97  
      98  
      99      /* k = 0 -> w = 1 */
     100      s = *x1;
     101      s = addmod(s, *x2, umod);
     102      s = addmod(s, *x3, umod);
     103  
     104      r1 = s;
     105  
     106      /* k = 1 */
     107      s = *x1;
     108  
     109      w = w3table[1];
     110      tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
     111      s = addmod(s, tmp, umod);
     112  
     113      w = w3table[2];
     114      tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
     115      s = addmod(s, tmp, umod);
     116  
     117      r2 = s;
     118  
     119      /* k = 2 */
     120      s = *x1;
     121  
     122      w = w3table[2];
     123      tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
     124      s = addmod(s, tmp, umod);
     125  
     126      w = w3table[1];
     127      tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
     128      s = addmod(s, tmp, umod);
     129  
     130      *x3 = s;
     131      *x2 = r2;
     132      *x1 = r1;
     133  }
     134  #endif
     135  
     136  
     137  /* forward transform, sign = -1; transform length = 3 * 2**n */
     138  int
     139  four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
     140  {
     141      mpd_size_t R = 3; /* number of rows */
     142      mpd_size_t C = n / 3; /* number of columns */
     143      mpd_uint_t w3table[3];
     144      mpd_uint_t kernel, w0, w1, wstep;
     145      mpd_uint_t *s, *p0, *p1, *p2;
     146      mpd_uint_t umod;
     147  #ifdef PPRO
     148      double dmod;
     149      uint32_t dinvmod[3];
     150  #endif
     151      mpd_size_t i, k;
     152  
     153  
     154      assert(n >= 48);
     155      assert(n <= 3*MPD_MAXTRANSFORM_2N);
     156  
     157  
     158      /* Length R transform on the columns. */
     159      SETMODULUS(modnum);
     160      _mpd_init_w3table(w3table, -1, modnum);
     161      for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
     162  
     163          SIZE3_NTT(p0, p1, p2, w3table);
     164      }
     165  
     166      /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
     167      kernel = _mpd_getkernel(n, -1, modnum);
     168      for (i = 1; i < R; i++) {
     169          w0 = 1;                  /* r**(i*0): initial value for k=0 */
     170          w1 = POWMOD(kernel, i);  /* r**(i*1): initial value for k=1 */
     171          wstep = MULMOD(w1, w1);  /* r**(2*i) */
     172          for (k = 0; k < C-1; k += 2) {
     173              mpd_uint_t x0 = a[i*C+k];
     174              mpd_uint_t x1 = a[i*C+k+1];
     175              MULMOD2(&x0, w0, &x1, w1);
     176              MULMOD2C(&w0, &w1, wstep);  /* r**(i*(k+2)) = r**(i*k) * r**(2*i) */
     177              a[i*C+k] = x0;
     178              a[i*C+k+1] = x1;
     179          }
     180      }
     181  
     182      /* Length C transform on the rows. */
     183      for (s = a; s < a+n; s += C) {
     184          if (!six_step_fnt(s, C, modnum)) {
     185              return 0;
     186          }
     187      }
     188  
     189  #if 0
     190      /* An unordered transform is sufficient for convolution. */
     191      /* Transpose the matrix. */
     192      #include "transpose.h"
     193      transpose_3xpow2(a, R, C);
     194  #endif
     195  
     196      return 1;
     197  }
     198  
     199  /* backward transform, sign = 1; transform length = 3 * 2**n */
     200  int
     201  inv_four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
     202  {
     203      mpd_size_t R = 3; /* number of rows */
     204      mpd_size_t C = n / 3; /* number of columns */
     205      mpd_uint_t w3table[3];
     206      mpd_uint_t kernel, w0, w1, wstep;
     207      mpd_uint_t *s, *p0, *p1, *p2;
     208      mpd_uint_t umod;
     209  #ifdef PPRO
     210      double dmod;
     211      uint32_t dinvmod[3];
     212  #endif
     213      mpd_size_t i, k;
     214  
     215  
     216      assert(n >= 48);
     217      assert(n <= 3*MPD_MAXTRANSFORM_2N);
     218  
     219  
     220  #if 0
     221      /* An unordered transform is sufficient for convolution. */
     222      /* Transpose the matrix, producing an R*C matrix. */
     223      #include "transpose.h"
     224      transpose_3xpow2(a, C, R);
     225  #endif
     226  
     227      /* Length C transform on the rows. */
     228      for (s = a; s < a+n; s += C) {
     229          if (!inv_six_step_fnt(s, C, modnum)) {
     230              return 0;
     231          }
     232      }
     233  
     234      /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
     235      SETMODULUS(modnum);
     236      kernel = _mpd_getkernel(n, 1, modnum);
     237      for (i = 1; i < R; i++) {
     238          w0 = 1;
     239          w1 = POWMOD(kernel, i);
     240          wstep = MULMOD(w1, w1);
     241          for (k = 0; k < C; k += 2) {
     242              mpd_uint_t x0 = a[i*C+k];
     243              mpd_uint_t x1 = a[i*C+k+1];
     244              MULMOD2(&x0, w0, &x1, w1);
     245              MULMOD2C(&w0, &w1, wstep);
     246              a[i*C+k] = x0;
     247              a[i*C+k+1] = x1;
     248          }
     249      }
     250  
     251      /* Length R transform on the columns. */
     252      _mpd_init_w3table(w3table, 1, modnum);
     253      for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
     254  
     255          SIZE3_NTT(p0, p1, p2, w3table);
     256      }
     257  
     258      return 1;
     259  }