(root)/
Python-3.12.0/
Modules/
_decimal/
libmpdec/
umodarith.h
       1  /*
       2   * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
       3   *
       4   * Redistribution and use in source and binary forms, with or without
       5   * modification, are permitted provided that the following conditions
       6   * are met:
       7   *
       8   * 1. Redistributions of source code must retain the above copyright
       9   *    notice, this list of conditions and the following disclaimer.
      10   *
      11   * 2. Redistributions in binary form must reproduce the above copyright
      12   *    notice, this list of conditions and the following disclaimer in the
      13   *    documentation and/or other materials provided with the distribution.
      14   *
      15   * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
      16   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
      17   * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
      18   * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
      19   * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
      20   * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
      21   * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
      22   * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
      23   * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
      24   * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
      25   * SUCH DAMAGE.
      26   */
      27  
      28  
      29  #ifndef LIBMPDEC_UMODARITH_H_
      30  #define LIBMPDEC_UMODARITH_H_
      31  
      32  
      33  #include "mpdecimal.h"
      34  
      35  #include "constants.h"
      36  #include "typearith.h"
      37  
      38  
      39  /* Bignum: Low level routines for unsigned modular arithmetic. These are
      40     used in the fast convolution functions for very large coefficients. */
      41  
      42  
      43  /**************************************************************************/
      44  /*                        ANSI modular arithmetic                         */
      45  /**************************************************************************/
      46  
      47  
      48  /*
      49   * Restrictions: a < m and b < m
      50   * ACL2 proof: umodarith.lisp: addmod-correct
      51   */
      52  static inline mpd_uint_t
      53  addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
      54  {
      55      mpd_uint_t s;
      56  
      57      s = a + b;
      58      s = (s < a) ? s - m : s;
      59      s = (s >= m) ? s - m : s;
      60  
      61      return s;
      62  }
      63  
      64  /*
      65   * Restrictions: a < m and b < m
      66   * ACL2 proof: umodarith.lisp: submod-2-correct
      67   */
      68  static inline mpd_uint_t
      69  submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
      70  {
      71      mpd_uint_t d;
      72  
      73      d = a - b;
      74      d = (a < b) ? d + m : d;
      75  
      76      return d;
      77  }
      78  
      79  /*
      80   * Restrictions: a < 2m and b < 2m
      81   * ACL2 proof: umodarith.lisp: section ext-submod
      82   */
      83  static inline mpd_uint_t
      84  ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
      85  {
      86      mpd_uint_t d;
      87  
      88      a = (a >= m) ? a - m : a;
      89      b = (b >= m) ? b - m : b;
      90  
      91      d = a - b;
      92      d = (a < b) ? d + m : d;
      93  
      94      return d;
      95  }
      96  
      97  /*
      98   * Reduce double word modulo m.
      99   * Restrictions: m != 0
     100   * ACL2 proof: umodarith.lisp: section dw-reduce
     101   */
     102  static inline mpd_uint_t
     103  dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
     104  {
     105      mpd_uint_t r1, r2, w;
     106  
     107      _mpd_div_word(&w, &r1, hi, m);
     108      _mpd_div_words(&w, &r2, r1, lo, m);
     109  
     110      return r2;
     111  }
     112  
     113  /*
     114   * Subtract double word from a.
     115   * Restrictions: a < m
     116   * ACL2 proof: umodarith.lisp: section dw-submod
     117   */
     118  static inline mpd_uint_t
     119  dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
     120  {
     121      mpd_uint_t d, r;
     122  
     123      r = dw_reduce(hi, lo, m);
     124      d = a - r;
     125      d = (a < r) ? d + m : d;
     126  
     127      return d;
     128  }
     129  
     130  #ifdef CONFIG_64
     131  
     132  /**************************************************************************/
     133  /*                        64-bit modular arithmetic                       */
     134  /**************************************************************************/
     135  
/*
 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
 * proof is in umodarith.lisp: section "Fast modular reduction".
 *
 * Algorithm: calculate (a * b) % p:
 *
 *   a) hi, lo <- a * b       # Calculate a * b.
 *
 *   b) hi, lo <-  R(hi, lo)  # Reduce modulo p.
 *
 *   c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
 *
 *   d) If the result is less than p, return lo. Otherwise return lo - p.
 *
 * Each reduction step R below rewrites
 *
 *   hi * 2**64 + lo  ->  hi * 2**k + lo - hi
 *
 * with k = 32 (P1 branch), k = 34 (P2 branch) or k = 40 (P3 branch).
 * The value is unchanged modulo p when 2**64 == 2**k - 1 (mod p), i.e. when
 * p = 2**64 - 2**k + 1.  The branch is chosen by testing bits 32 and 34 of m.
 *
 * NOTE(review): this assumes m is one of the three transform primes
 * P1 = 2**64-2**32+1, P2 = 2**64-2**34+1, P3 = 2**64-2**40+1 (defined
 * elsewhere, presumably constants.c) -- confirm.  The number of reduction
 * steps per branch (2, 3, 3) is presumably what the proofs above require to
 * reach the bound of step c).
 */

static inline mpd_uint_t
x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
{
    mpd_uint_t hi, lo, x, y;


    /* Step a): full double-word product a * b. */
    _mpd_mul_words(&hi, &lo, a, b);

    if (m & (1ULL<<32)) { /* P1 */

        /* first reduction: (hi, lo) <- hi * 2**32 + lo - hi */
        x = y = hi;
        hi >>= 32;            /* high word of hi * 2**32 */

        x = lo - x;           /* low word of lo - hi ... */
        if (x > lo) hi--;     /* ... and its borrow */

        y <<= 32;             /* low word of hi * 2**32 */
        lo = y + x;
        if (lo < y) hi++;     /* carry into the high word */

        /* second reduction (same step) */
        x = y = hi;
        hi >>= 32;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 32;
        lo = y + x;
        if (lo < y) hi++;

        /* Step d): hi != 0 means the value is >= 2**64 > m.  In that case
           lo - m wraps around to the correct residue. */
        return (hi || lo >= m ? lo - m : lo);
    }
    else if (m & (1ULL<<34)) { /* P2 */

        /* first reduction: (hi, lo) <- hi * 2**34 + lo - hi */
        x = y = hi;
        hi >>= 30;            /* high word of hi * 2**34 */

        x = lo - x;
        if (x > lo) hi--;

        y <<= 34;             /* low word of hi * 2**34 */
        lo = y + x;
        if (lo < y) hi++;

        /* second reduction */
        x = y = hi;
        hi >>= 30;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 34;
        lo = y + x;
        if (lo < y) hi++;

        /* third reduction */
        x = y = hi;
        hi >>= 30;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 34;
        lo = y + x;
        if (lo < y) hi++;

        /* Step d): see the P1 branch. */
        return (hi || lo >= m ? lo - m : lo);
    }
    else { /* P3 */

        /* first reduction: (hi, lo) <- hi * 2**40 + lo - hi */
        x = y = hi;
        hi >>= 24;            /* high word of hi * 2**40 */

        x = lo - x;
        if (x > lo) hi--;

        y <<= 40;             /* low word of hi * 2**40 */
        lo = y + x;
        if (lo < y) hi++;

        /* second reduction */
        x = y = hi;
        hi >>= 24;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 40;
        lo = y + x;
        if (lo < y) hi++;

        /* third reduction */
        x = y = hi;
        hi >>= 24;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 40;
        lo = y + x;
        if (lo < y) hi++;

        /* Step d): see the P1 branch. */
        return (hi || lo >= m ? lo - m : lo);
    }
}
     260  
     261  static inline void
     262  x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
     263  {
     264      *a = x64_mulmod(*a, w, m);
     265      *b = x64_mulmod(*b, w, m);
     266  }
     267  
     268  static inline void
     269  x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
     270              mpd_uint_t m)
     271  {
     272      *a0 = x64_mulmod(*a0, b0, m);
     273      *a1 = x64_mulmod(*a1, b1, m);
     274  }
     275  
     276  static inline mpd_uint_t
     277  x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
     278  {
     279      mpd_uint_t r = 1;
     280  
     281      while (exp > 0) {
     282          if (exp & 1)
     283              r = x64_mulmod(r, base, umod);
     284          base = x64_mulmod(base, base, umod);
     285          exp >>= 1;
     286      }
     287  
     288      return r;
     289  }
     290  
     291  /* END CONFIG_64 */
     292  #else /* CONFIG_32 */
     293  
     294  
     295  /**************************************************************************/
     296  /*                        32-bit modular arithmetic                       */
     297  /**************************************************************************/
     298  
     299  #if defined(ANSI)
     300  #if !defined(LEGACY_COMPILER)
     301  /* HAVE_UINT64_T */
     302  static inline mpd_uint_t
     303  std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
     304  {
     305      return ((mpd_uuint_t) a * b) % m;
     306  }
     307  
     308  static inline void
     309  std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
     310  {
     311      *a = ((mpd_uuint_t) *a * w) % m;
     312      *b = ((mpd_uuint_t) *b * w) % m;
     313  }
     314  
     315  static inline void
     316  std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
     317              mpd_uint_t m)
     318  {
     319      *a0 = ((mpd_uuint_t) *a0 * b0) % m;
     320      *a1 = ((mpd_uuint_t) *a1 * b1) % m;
     321  }
     322  /* END HAVE_UINT64_T */
     323  #else
     324  /* LEGACY_COMPILER */
     325  static inline mpd_uint_t
     326  std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
     327  {
     328      mpd_uint_t hi, lo, q, r;
     329      _mpd_mul_words(&hi, &lo, a, b);
     330      _mpd_div_words(&q, &r, hi, lo, m);
     331      return r;
     332  }
     333  
     334  static inline void
     335  std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
     336  {
     337      *a = std_mulmod(*a, w, m);
     338      *b = std_mulmod(*b, w, m);
     339  }
     340  
     341  static inline void
     342  std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
     343              mpd_uint_t m)
     344  {
     345      *a0 = std_mulmod(*a0, b0, m);
     346      *a1 = std_mulmod(*a1, b1, m);
     347  }
     348  /* END LEGACY_COMPILER */
     349  #endif
     350  
     351  static inline mpd_uint_t
     352  std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
     353  {
     354      mpd_uint_t r = 1;
     355  
     356      while (exp > 0) {
     357          if (exp & 1)
     358              r = std_mulmod(r, base, umod);
     359          base = std_mulmod(base, base, umod);
     360          exp >>= 1;
     361      }
     362  
     363      return r;
     364  }
     365  #endif /* ANSI CONFIG_32 */
     366  
     367  
     368  /**************************************************************************/
     369  /*                    Pentium Pro modular arithmetic                      */
     370  /**************************************************************************/
     371  
     372  /*
     373   * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
     374   * control word must be set to 64-bit precision and truncation mode
     375   * prior to using these functions.
     376   *
     377   * Algorithm: calculate (a * b) % p:
     378   *
     379   *   p    := prime < 2**31
     380   *   pinv := (long double)1.0 / p (precalculated)
     381   *
     382   *   a) n = a * b              # Calculate exact product.
     383   *   b) qest = n * pinv        # Calculate estimate for q = n / p.
     384   *   c) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.
     385   *   d) r = n - q * p          # Calculate remainder.
     386   *
     387   * Remarks:
     388   *
     389   *   - p = dmod and pinv = dinvmod.
     390   *   - dinvmod points to an array of three uint32_t, which is interpreted
     391   *     as an 80 bit long double by fldt.
     392   *   - Intel compilers prior to version 11 do not seem to handle the
     393   *     __GNUC__ inline assembly correctly.
     394   *   - random tests are provided in tests/extended/ppro_mulmod.c
     395   */
     396  
     397  #if defined(PPRO)
     398  #if defined(ASM)
     399  
/*
 * Return (a * b) % dmod using x87 floating point (GCC inline assembly).
 *
 * Implements steps a)-d) of the algorithm described at the start of the
 * Pentium Pro section: dmod points to p as a double, and dinvmod points to
 * 1/p as an 80-bit long double.  The FPU control word must already be set
 * to 64-bit precision and truncation mode.
 *
 * Operands: %0 = retval (memory output), %1 = a, %2 = b, %3 = dmod,
 *           %4 = dinvmod, %5 = MPD_TWO63 (loaded with flds, i.e. as a
 *           single-precision float).
 */
static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm__ (
            "fildl  %2\n\t"               /* load b */
            "fildl  %1\n\t"               /* load a */
            "fmulp  %%st, %%st(1)\n\t"    /* a) n = a * b */
            "fldt   (%4)\n\t"             /* load pinv */
            "fmul   %%st(1), %%st\n\t"    /* b) qest = n * pinv */
            "flds   %5\n\t"               /* load 2**63 */
            "fadd   %%st, %%st(1)\n\t"    /* c) q = (qest + 2**63) ... */
            "fsubrp %%st, %%st(1)\n\t"    /*    ... - 2**63 */
            "fldl   (%3)\n\t"             /* load p */
            "fmulp  %%st, %%st(1)\n\t"    /* q * p */
            "fsubrp %%st, %%st(1)\n\t"    /* d) r = n - q * p */
            "fistpl %0\n\t"               /* retval = r */
            : "=m" (retval)
            : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
            : "st", "memory"
    );

    return retval;
}
     426  
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * w) % dmod
 *      *a1 = (*a1 * w) % dmod
 *
 * x87 implementation (GCC inline assembly) of the algorithm described at
 * the start of the Pentium Pro section.  The two reductions are interleaved
 * in a single instruction sequence, presumably to overlap FPU latencies
 * (TODO confirm).  The FPU control word must already be set to 64-bit
 * precision and truncation mode.
 *
 * Operands: %0 = a0, %1 = a1 (pointers in registers), %2 = w, %3 = dmod,
 *           %4 = dinvmod, %5 = MPD_TWO63.  There are no output operands:
 *           results are stored through a0 and a1, which the "memory" clobber
 *           declares to the compiler.
 */
static inline void
ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
              double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            /* form the products *a1 * w and *a0 * w */
            "fildl  %2\n\t"
            "fildl  (%1)\n\t"
            "fmul   %%st(1), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fildl  (%0)\n\t"
            "fmulp  %%st, %%st(1) \n\t"
            /* load pinv and 2**63, then reduce the first product */
            "fldt   (%4)\n\t"
            "flds   %5\n\t"
            "fld    %%st(2)\n\t"
            "fmul   %%st(2)\n\t"
            "fadd   %%st(1)\n\t"
            "fsub   %%st(1)\n\t"
            "fmull  (%3)\n\t"
            "fsubrp %%st, %%st(3)\n\t"
            "fxch   %%st(2)\n\t"
            "fistpl (%0)\n\t"             /* store *a0 */
            /* reduce the second product */
            "fmul   %%st(2)\n\t"
            "fadd   %%st(1)\n\t"
            "fsubp  %%st, %%st(1)\n\t"
            "fmull  (%3)\n\t"
            "fsubrp %%st, %%st(1)\n\t"
            "fistpl (%1)\n\t"             /* store *a1 */
            : : "r" (a0), "r" (a1), "m" (w),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
     465  
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * x87 implementation (GCC inline assembly) of the algorithm described at
 * the start of the Pentium Pro section.  The two reductions are interleaved
 * in one instruction sequence.  The FPU control word must already be set to
 * 64-bit precision and truncation mode.
 *
 * Operands: %0 = a0, %1 = b0, %2 = a1, %3 = b1, %4 = dmod, %5 = dinvmod,
 *           %6 = MPD_TWO63.  There are no output operands: results are
 *           stored through a1 and a0 (covered by the "memory" clobber).
 */
static inline void
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            /* form the products *a1 * b1 and *a0 * b0 */
            "fildl  %3\n\t"
            "fildl  (%2)\n\t"
            "fmulp  %%st, %%st(1)\n\t"
            "fildl  %1\n\t"
            "fildl  (%0)\n\t"
            "fmulp  %%st, %%st(1)\n\t"
            /* interleaved reduction of both products */
            "fldt   (%5)\n\t"
            "fld    %%st(2)\n\t"
            "fmul   %%st(1), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fmul   %%st(2), %%st\n\t"
            "flds   %6\n\t"
            "fldl   (%4)\n\t"
            "fxch   %%st(3)\n\t"
            "fadd   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fadd   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fsub   %%st(1), %%st\n\t"
            "fxch   %%st(2)\n\t"
            "fsubp  %%st, %%st(1)\n\t"
            "fxch   %%st(1)\n\t"
            "fmul   %%st(2), %%st\n\t"
            "fxch   %%st(1)\n\t"
            "fmulp  %%st, %%st(2)\n\t"
            "fsubrp %%st, %%st(3)\n\t"
            "fsubrp %%st, %%st(1)\n\t"
            "fxch   %%st(1)\n\t"
            "fistpl (%2)\n\t"             /* store *a1 */
            "fistpl (%0)\n\t"             /* store *a0 */
            : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
     512  /* END PPRO GCC ASM */
     513  #elif defined(MASM)
     514  
/*
 * Return (a * b) % dmod using x87 floating point (MSVC inline assembly).
 *
 * Implements steps a)-d) of the algorithm described at the start of the
 * Pentium Pro section: dmod points to p as a double, and dinvmod points to
 * 1/p as an 80-bit long double.  The FPU control word must already be set
 * to 64-bit precision and truncation mode.
 */
static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm {
        mov     eax, dinvmod
        mov     edx, dmod
        /* a) n = a * b */
        fild    b
        fild    a
        fmulp   st(1), st
        /* b) qest = n * pinv */
        fld     TBYTE PTR [eax]
        fmul    st, st(1)
        /* c) q = (qest + 2**63) - 2**63 */
        fld     MPD_TWO63
        fadd    st(1), st
        fsubp   st(1), st
        /* d) r = n - q * p */
        fld     QWORD PTR [edx]
        fmulp   st(1), st
        fsubp   st(1), st
        fistp   retval
    }

    return retval;
}
     540  
     541  /*
     542   * Two modular multiplications in parallel:
     543   *      *a0 = (*a0 * w) % dmod
     544   *      *a1 = (*a1 * w) % dmod
     545   */
     546  static inline mpd_uint_t __cdecl
     547  ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
     548                double *dmod, uint32_t *dinvmod)
     549  {
     550      __asm {
     551          mov     ecx, dmod
     552          mov     edx, a1
     553          mov     ebx, dinvmod
     554          mov     eax, a0
     555          fild    w
     556          fild    DWORD PTR [edx]
     557          fmul    st, st(1)
     558          fxch    st(1)
     559          fild    DWORD PTR [eax]
     560          fmulp   st(1), st
     561          fld     TBYTE PTR [ebx]
     562          fld     MPD_TWO63
     563          fld     st(2)
     564          fmul    st, st(2)
     565          fadd    st, st(1)
     566          fsub    st, st(1)
     567          fmul    QWORD PTR [ecx]
     568          fsubp   st(3), st
     569          fxch    st(2)
     570          fistp   DWORD PTR [eax]
     571          fmul    st, st(2)
     572          fadd    st, st(1)
     573          fsubrp  st(1), st
     574          fmul    QWORD PTR [ecx]
     575          fsubp   st(1), st
     576          fistp   DWORD PTR [edx]
     577      }
     578  }
     579  
/*
 * Two modular multiplications in parallel:
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * x87 implementation (MSVC inline assembly) of the algorithm described at
 * the start of the Pentium Pro section.  The two reductions are interleaved
 * in one instruction sequence.  The FPU control word must already be set to
 * 64-bit precision and truncation mode.
 */
static inline void __cdecl
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm {
        mov     ecx, dmod
        mov     edx, a1
        mov     ebx, dinvmod
        mov     eax, a0
        /* form the products *a1 * b1 and *a0 * b0 */
        fild    b1
        fild    DWORD PTR [edx]
        fmulp   st(1), st
        fild    b0
        fild    DWORD PTR [eax]
        fmulp   st(1), st
        /* interleaved reduction of both products */
        fld     TBYTE PTR [ebx]
        fld     st(2)
        fmul    st, st(1)
        fxch    st(1)
        fmul    st, st(2)
        fld     DWORD PTR MPD_TWO63
        fld     QWORD PTR [ecx]
        fxch    st(3)
        fadd    st, st(1)
        fxch    st(2)
        fadd    st, st(1)
        fxch    st(2)
        fsub    st, st(1)
        fxch    st(2)
        fsubrp  st(1), st
        fxch    st(1)
        fmul    st, st(2)
        fxch    st(1)
        fmulp   st(2), st
        fsubp   st(3), st
        fsubp   st(1), st
        fxch    st(1)
        /* store *a1, then *a0 */
        fistp   DWORD PTR [edx]
        fistp   DWORD PTR [eax]
    }
}
     626  #endif /* PPRO MASM (_MSC_VER) */
     627  
     628  
     629  /* Return (base ** exp) % dmod */
     630  static inline mpd_uint_t
     631  ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
     632  {
     633      mpd_uint_t r = 1;
     634  
     635      while (exp > 0) {
     636          if (exp & 1)
     637              r = ppro_mulmod(r, base, dmod, dinvmod);
     638          base = ppro_mulmod(base, base, dmod, dinvmod);
     639          exp >>= 1;
     640      }
     641  
     642      return r;
     643  }
     644  #endif /* PPRO */
     645  #endif /* CONFIG_32 */
     646  
     647  
     648  #endif /* LIBMPDEC_UMODARITH_H_ */