(root)/
gmp-6.3.0/
mpn/
generic/
sec_powm.c
       1  /* mpn_sec_powm -- Compute R = U^E mod M.  Secure variant, side-channel silent
       2     under the assumption that the multiply instruction is side channel silent.
       3  
       4     Contributed to the GNU project by Torbjörn Granlund.
       5  
       6  Copyright 2007-2009, 2011-2014, 2018-2019, 2021 Free Software Foundation, Inc.
       7  
       8  This file is part of the GNU MP Library.
       9  
      10  The GNU MP Library is free software; you can redistribute it and/or modify
      11  it under the terms of either:
      12  
      13    * the GNU Lesser General Public License as published by the Free
      14      Software Foundation; either version 3 of the License, or (at your
      15      option) any later version.
      16  
      17  or
      18  
      19    * the GNU General Public License as published by the Free Software
      20      Foundation; either version 2 of the License, or (at your option) any
      21      later version.
      22  
      23  or both in parallel, as here.
      24  
      25  The GNU MP Library is distributed in the hope that it will be useful, but
      26  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
      27  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
      28  for more details.
      29  
      30  You should have received copies of the GNU General Public License and the
      31  GNU Lesser General Public License along with the GNU MP Library.  If not,
      32  see https://www.gnu.org/licenses/.  */
      33  
      34  
      35  /*
      36    BASIC ALGORITHM, Compute U^E mod M, where M < B^n is odd.
      37  
      38    1. T <- (B^n * U) mod M; convert to REDC form
      39  
      40    2. Compute table U^0, U^1, U^2... of floor(log(E))-dependent size
      41  
      42    3. While there are more bits in E
      43         W <- power left-to-right base-k
      44  
      45    The article "Defeating modexp side-channel attacks with data-independent
      46    execution traces", https://gmplib.org/~tege/modexp-silent.pdf, has details.
      47  
      48  
      49    TODO:
      50  
      51     * Make getbits a macro, thereby allowing it to update the index operand.
      52       That will simplify the code using getbits.  (Perhaps make getbits' sibling
      53       getbit then have similar form, for symmetry.)
      54  
      55     * Choose window size without looping.  (Superoptimize or think(tm).)
      56  
      57     * REDC_1_TO_REDC_2_THRESHOLD might actually represent the cutoff between
      58       redc_1 and redc_n.  On such systems, we will switch to redc_2 causing
      59       slowdown.
      60  */
      61  
      62  #include "gmp-impl.h"
      63  #include "longlong.h"
      64  
      65  #undef MPN_REDC_1_SEC
          /* MPN_REDC_1_SEC (rp, up, mp, n, invm): REDC of the operand at up modulo
             {mp,n} using the single-limb inverse invm, leaving the result in {rp,n}.
             The carry returned by the reduction primitive is never branched on; it
             is handed to mpn_cnd_sub_n as the condition for the final subtraction
             of {mp,n}.  With a native mpn_sbpi1_bdiv_r the reduction runs in place
             on {up,2n} and the result is taken from up + n; otherwise mpn_redc_1
             writes to rp and the correction is applied there.  */
      66  #if HAVE_NATIVE_mpn_sbpi1_bdiv_r
      67  #define MPN_REDC_1_SEC(rp, up, mp, n, invm)				\
      68    do {									\
      69      mp_limb_t cy;							\
      70      cy = mpn_sbpi1_bdiv_r (up, 2 * n, mp, n, invm);			\
      71      mpn_cnd_sub_n (cy, rp, up + n, mp, n);				\
      72    } while (0)
      73  #else
      74  #define MPN_REDC_1_SEC(rp, up, mp, n, invm)				\
      75    do {									\
      76      mp_limb_t cy;							\
      77      cy = mpn_redc_1 (rp, up, mp, n, invm);				\
      78      mpn_cnd_sub_n (cy, rp, rp, mp, n);					\
      79    } while (0)
      80  #endif
      81  
      82  #if HAVE_NATIVE_mpn_addmul_2 || HAVE_NATIVE_mpn_redc_2
          /* MPN_REDC_2_SEC (rp, up, mp, n, mip): same shape as MPN_REDC_1_SEC, but
             built on mpn_redc_2, which is given the pointer mip to the inverse
             limbs (mpn_sec_powm fills in mip[0] and mip[1]).  When neither
             mpn_addmul_2 nor mpn_redc_2 is native, the macro expands to nothing and
             REDC_1_TO_REDC_2_THRESHOLD is overridden with MP_SIZE_T_MAX; the intent
             appears to be that the redc_2 branches are then never selected.  */
      83  #undef MPN_REDC_2_SEC
      84  #define MPN_REDC_2_SEC(rp, up, mp, n, mip)				\
      85    do {									\
      86      mp_limb_t cy;							\
      87      cy = mpn_redc_2 (rp, up, mp, n, mip);				\
      88      mpn_cnd_sub_n (cy, rp, rp, mp, n);					\
      89    } while (0)
      90  #else
      91  #define MPN_REDC_2_SEC(rp, up, mp, n, mip) /* empty */
      92  #undef REDC_1_TO_REDC_2_THRESHOLD
      93  #define REDC_1_TO_REDC_2_THRESHOLD MP_SIZE_T_MAX
      94  #endif
      95  
      96  /* Define our own mpn squaring function.  We do this since we cannot use a
      97     native mpn_sqr_basecase over TUNE_SQR_TOOM2_MAX, or a non-native one over
      98     SQR_TOOM2_THRESHOLD.  This is so because of fixed size stack allocations
      99     made inside mpn_sqr_basecase.  */
     100  
     101  #if ! HAVE_NATIVE_mpn_sqr_basecase
     102  /* The limit of the generic code is SQR_TOOM2_THRESHOLD.  */
     103  #define SQR_BASECASE_LIM  SQR_TOOM2_THRESHOLD
     104  #endif
     105  
     106  #if HAVE_NATIVE_mpn_sqr_basecase
     107  #ifdef TUNE_SQR_TOOM2_MAX
     108  /* We slightly abuse TUNE_SQR_TOOM2_MAX here.  If it is set for an assembly
     109     mpn_sqr_basecase, it comes from SQR_TOOM2_THRESHOLD_MAX in the assembly
     110     file.  An assembly mpn_sqr_basecase that does not define it should allow
     111     any size.  */
     112  #define SQR_BASECASE_LIM  SQR_TOOM2_THRESHOLD
     113  #endif
     114  #endif
     115  
     116  #ifdef WANT_FAT_BINARY
     117  /* For fat builds, we use SQR_TOOM2_THRESHOLD which will expand to a read from
     118     __gmpn_cpuvec.  Perhaps any possible sqr_basecase.asm allow any size, and we
     119     limit the use unnecessarily.  We cannot tell, so play it safe.  FIXME.  */
     120  #define SQR_BASECASE_LIM  SQR_TOOM2_THRESHOLD
     121  #endif
     122  
     123  #ifndef SQR_BASECASE_LIM
     124  /* If SQR_BASECASE_LIM is now not defined, use mpn_sqr_basecase for any operand
     125     size.  */
     126  #define SQR_BASECASE_LIM  MP_SIZE_T_MAX
     127  #endif
     128  
          /* mpn_local_sqr (rp, up, n): square {up,n} into rp.  mpn_sqr_basecase is
             used only when n passes both threshold tests (SQR_BASECASE_THRESHOLD and
             the SQR_BASECASE_LIM selected above); every other size falls back to
             mpn_mul_basecase with up as both operands.  The blocks above may define
             SQR_BASECASE_LIM more than once, but always with the same replacement
             text.  n is expanded several times here, so pass an expression without
             side effects.  */
     129  #define mpn_local_sqr(rp,up,n)						\
     130    do {									\
     131      if (ABOVE_THRESHOLD (n, SQR_BASECASE_THRESHOLD)			\
     132  	&& BELOW_THRESHOLD (n, SQR_BASECASE_LIM))			\
     133        mpn_sqr_basecase (rp, up, n);					\
     134      else								\
     135        mpn_mul_basecase(rp, up, n, up, n);				\
     136    } while (0)
     137  
          /* getbit (p, bi): the bit of the limb array p at index bi - 1, counted
             from the least significant bit of p[0], as 0 or 1.  bi is therefore a
             1-based bit position and must be > 0.  The arguments are not
             parenthesized in the expansion, so pass simple expressions.  No use of
             this macro is visible in this file.  */
     138  #define getbit(p,bi) \
     139    ((p[(bi - 1) / GMP_NUMB_BITS] >> (bi - 1) % GMP_NUMB_BITS) & 1)
     140  
     141  /* Return the nbits-bit field of the bit string at p whose top end is bit
     142     position bi, i.e. bits bi-nbits .. bi-1.  When fewer than nbits bits are
     143     left (bi < nbits), the low bi bits of p[0] are returned instead.  FIXME:
     144     it might be simpler if all callers ensured bi >= nbits; it seems that
     145     bi < nbits can happen only for the final iteration.  */
     146  static inline mp_limb_t
     147  getbits (const mp_limb_t *p, mp_bitcnt_t bi, int nbits)
     148  {
     149    mp_bitcnt_t low, shift;
     150    mp_size_t limb;
     151    mp_limb_t field;
     152    int have;
     153  
     154    if (bi < nbits)
     155      return p[0] & (((mp_limb_t) 1 << bi) - 1);
     156  
     157    low = bi - nbits;			/* position of the lowest wanted bit */
     158    limb = low / GMP_NUMB_BITS;		/* limb holding that bit */
     159    shift = low % GMP_NUMB_BITS;		/* its offset within the limb */
     160    field = p[limb] >> shift;
     161    have = GMP_NUMB_BITS - shift;		/* bits obtained from p[limb] */
     162    /* If the field straddles a limb boundary, add in the bits from above.  */
     163    if (have < nbits)
     164      field += p[limb + 1] << have;
     165    return field & (((mp_limb_t) 1 << nbits) - 1);
     166  }
     167  
     168  #ifndef POWM_SEC_TABLE
          /* Default thresholds consumed by win_size: the j-th value (counting from
             1) is the largest exponent bit count for which window size j is
             chosen; anything above the last value gets the next larger window.
             Separate defaults are given for limbs narrower than 50 bits and for
             wider limbs.  A definition supplied before this point takes
             precedence.  */
     169  #if GMP_NUMB_BITS < 50
     170  #define POWM_SEC_TABLE  2,33,96,780,2741
     171  #else
     172  #define POWM_SEC_TABLE  2,130,524,2578
     173  #endif
     174  #endif
     175  
     176  #if TUNE_PROGRAM_BUILD
     177  extern int win_size (mp_bitcnt_t);
     178  #else
     179  /* Return the window size, in bits, to use for an enb-bit exponent.  */
     180  static inline int
     181  win_size (mp_bitcnt_t enb)
     182  {
     183    /* limit[j] is the largest exponent size handled with window size j + 1;
     184       the all-ones sentinel guarantees that the scan stops.  As long as
     185       limit[j] >= j + 1 and enb > 0, the size found is at most enb.  */
     186    static const mp_bitcnt_t limit[] = {POWM_SEC_TABLE,~(mp_bitcnt_t)0};
     187    int w = 0;
     188    while (enb > limit[w])
     189      w++;
     190    w++;				/* table index -> window size */
     191    ASSERT (w <= enb);
     192    return w;
     193  }
     194  #endif
     195  
     196  /* Put {up,un} into REDC form: {rp,n} <- B^n * {up,un} mod {mp,n}.  tp must
     197     hold the un+n limb dividend followed by the scratch of mpn_sec_div_r.  */
     198  static void
     199  redcify (mp_ptr rp, mp_srcptr up, mp_size_t un, mp_srcptr mp, mp_size_t n, mp_ptr tp)
     200  {
     201    mp_size_t nn = un + n;	/* limbs in the dividend U * B^n */
     202    MPN_ZERO (tp, n);
     203    MPN_COPY (tp + n, up, un);
     204    mpn_sec_div_r (tp, nn, mp, n, tp + nn);
     205    MPN_COPY (rp, tp, n);
     206  }
     207  
          /* Return the inverse of the odd limb n modulo 2^GMP_NUMB_BITS (the closing
             ASSERT checks inv * n == 1 under GMP_NUMB_MASK).  The computation is
             straight-line arithmetic: every #if/if below tests the compile-time
             constant GMP_NUMB_BITS, and the trip count of the trailing loop depends
             on GMP_NUMB_BITS only, so no control flow depends on n.  The short
             "a x b + c -> d" notes appear to record how many low bits of inv are
             correct after each step.  */
     208  static mp_limb_t
     209  sec_binvert_limb (mp_limb_t n)
     210  {
     211    mp_limb_t inv, t;
     212    ASSERT ((n & 1) == 1);
     213    /* 3 + 2 -> 5 */
     214    inv = n + (((n + 1) << 1) & 0x18);
     215  
     216    t = n * inv;
     217  #if GMP_NUMB_BITS <= 10
     218    /* 5 x 2 -> 10 */
     219    inv = 2 * inv - inv * t;
     220  #else /* GMP_NUMB_BITS > 10 */
     221    /* 5 x 2 + 2 -> 12 */
     222    inv = 2 * inv - inv * t + ((inv<<10)&-(t&(1<<5)));
     223  #endif /* GMP_NUMB_BITS <= 10 */
     224  
     225    if (GMP_NUMB_BITS > 12)
     226      {
     227        t = n * inv - 1;
     228        if (GMP_NUMB_BITS <= 36)
     229  	{
     230  	  /* 12 x 3 -> 36 */
     231  	  inv += inv * t * (t - 1);
     232  	}
     233        else /* GMP_NUMB_BITS > 36 */
     234  	{
     235  	  mp_limb_t t2 = t * t;
     236  #if GMP_NUMB_BITS <= 60
     237  	  /* 12 x 5 -> 60 */
     238  	  inv += inv * (t2 + 1) * (t2 - t);
     239  #else /* GMP_NUMB_BITS > 60 */
     240  	  /* 12 x 5 + 4 -> 64 */
     241  	  inv *= (t2 + 1) * (t2 - t) + 1 - ((t<<48)&-(t&(1<<12)));
     242  
     243  	  /* 64 -> 128 -> 256 -> ... */
     244  	  for (int todo = (GMP_NUMB_BITS - 1) >> 6; todo != 0; todo >>= 1)
     245  	    inv = 2 * inv - inv * inv * n;
     246  #endif /* GMP_NUMB_BITS <= 60 */
     247  	}
     248      }
     249  
     250    ASSERT ((inv * n & GMP_NUMB_MASK) == 1);
     251    return inv & GMP_NUMB_MASK;
     252  }
     253  
     254  /* {rp, n} <-- {bp, bn} ^ {ep, en} mod {mp, n},
     255     where en = ceil (enb / GMP_NUMB_BITS)
     256     Requires that {mp, n} is odd (and hence also mp[0] odd).
     257     Uses scratch space at tp as defined by mpn_sec_powm_itch.

             The ASSERTs below additionally require enb > 0, n > 0 and bn > 0.

             Steps:
             1. Choose the window size from enb and compute the negated limb
                inverse of mp[0] (plus a second limb when the redc_2 path is used).
             2. Build, at the start of tp, a table of 2^windowsize powers of the
                base in REDC form.
             3. Start from the table entry selected by the top window of the
                exponent; then, per remaining window, square once per window bit
                and multiply by the entry fetched with mpn_sec_tabselect.
             4. Convert the result out of REDC form and conditionally subtract mp
                one final time.  */
     258  void
     259  mpn_sec_powm (mp_ptr rp, mp_srcptr bp, mp_size_t bn,
     260  	      mp_srcptr ep, mp_bitcnt_t enb,
     261  	      mp_srcptr mp, mp_size_t n, mp_ptr tp)
     262  {
     263    mp_limb_t ip[2], *mip;
     264    int windowsize, this_windowsize;
     265    mp_limb_t expbits;
     266    mp_ptr pp, this_pp, ps;
     267    long i;
     268    int cnd;
     269  
     270    ASSERT (enb > 0);
     271    ASSERT (n > 0);
     272    /* The code works for bn = 0, but the defined scratch space is 2 limbs
     273       greater than we supply, when converting 1 to redc form.  */
     274    ASSERT (bn > 0);
     275    ASSERT ((mp[0] & 1) != 0);
     276  
     277    windowsize = win_size (enb);
     278  
            /* mip[0] starts as the inverse of mp[0]; for sizes above the redc_2
               threshold a second limb mip[1] is derived from it and mp[1].  */
     279    mip = ip;
     280    mip[0] = sec_binvert_limb (mp[0]);
     281    if (ABOVE_THRESHOLD (n, REDC_1_TO_REDC_2_THRESHOLD))
     282      {
     283        mp_limb_t t, dummy, mip0 = mip[0];
     284  
     285        umul_ppmm (t, dummy, mip0, mp[0]);
     286        ASSERT (dummy == 1);
     287        t += mip0 * mp[1]; /* t = (mp * mip0)[1] */
     288  
     289        mip[1] = t * mip0 - 1; /* ~( - t * mip0) */
     290      }
            /* The REDC macros are handed the negated inverse.  */
     291    mip[0] = -mip[0];
     292  
            /* The power table occupies the first n * 2^windowsize limbs of the
               caller's scratch; everything after it is working space.  */
     293    pp = tp;
     294    tp += (n << windowsize);	/* put tp after power table */
     295  
     296    /* Compute pp[0] table entry */
     297    /* scratch: |   n   | 1 |   n+2    |  */
     298    /*          | pp[0] | 1 | redcify  |  */
     299    this_pp = pp;
     300    this_pp[n] = 1;
     301    redcify (this_pp, this_pp + n, 1, mp, n, this_pp + n + 1);
     302    this_pp += n;
     303  
     304    /* Compute pp[1] table entry.  To avoid excessive scratch usage in the
     305       degenerate situation where B >> M, we let redcify use scratch space which
     306       will later be used by the pp table (element 2 and up).  */
     307    /* scratch: |   n   |   n   |  bn + n + 1  |  */
     308    /*          | pp[0] | pp[1] |   redcify    |  */
     309    redcify (this_pp, bp, bn, mp, n, this_pp + n);
     310  
     311    /* Precompute powers of b and put them in the temporary area at pp.  */
     312    /* scratch: |   n   |   n   | ...  |                    |   2n      |  */
     313    /*          | pp[0] | pp[1] | ...  | pp[2^windowsize-1] |  product  |  */
            /* Each pass fills two entries: squaring the entry at ps yields the next
               even power, and multiplying that by pp[1] yields the following odd
               power.  The counter runs 2^windowsize - 2, ..., 2, so entries
               2 .. 2^windowsize - 1 are produced (none when windowsize is 1).  */
     314    ps = pp + n;		/* initially B^1 */
     315    if (BELOW_THRESHOLD (n, REDC_1_TO_REDC_2_THRESHOLD))
     316      {
     317        for (i = (1 << windowsize) - 2; i > 0; i -= 2)
     318  	{
     319  	  mpn_local_sqr (tp, ps, n);
     320  	  ps += n;
     321  	  this_pp += n;
     322  	  MPN_REDC_1_SEC (this_pp, tp, mp, n, mip[0]);
     323  
     324  	  mpn_mul_basecase (tp, this_pp, n, pp + n, n);
     325  	  this_pp += n;
     326  	  MPN_REDC_1_SEC (this_pp, tp, mp, n, mip[0]);
     327  	}
     328      }
     329    else
     330      {
     331        for (i = (1 << windowsize) - 2; i > 0; i -= 2)
     332  	{
     333  	  mpn_local_sqr (tp, ps, n);
     334  	  ps += n;
     335  	  this_pp += n;
     336  	  MPN_REDC_2_SEC (this_pp, tp, mp, n, mip);
     337  
     338  	  mpn_mul_basecase (tp, this_pp, n, pp + n, n);
     339  	  this_pp += n;
     340  	  MPN_REDC_2_SEC (this_pp, tp, mp, n, mip);
     341  	}
     342      }
     343  
            /* Initialize rp with the table entry selected by the top windowsize
               bits of the exponent.  */
     344    expbits = getbits (ep, enb, windowsize);
     345    ASSERT_ALWAYS (enb >= windowsize);
     346    enb -= windowsize;
     347  
     348    mpn_sec_tabselect (rp, pp, n, 1 << windowsize, expbits);
     349  
     350    /* Main exponentiation loop.  */
     351    /* scratch: |   n   |   n   | ...  |                    |     3n-4n     |  */
     352    /*          | pp[0] | pp[1] | ...  | pp[2^windowsize-1] |  loop scratch |  */
     353  
          /* INNERLOOP is expanded twice below, each time after MPN_REDUCE has been
             redefined for the REDC variant of that branch.  Per window it squares rp
             this_windowsize times (fewer than windowsize only for a short final
             window), then multiplies by the table entry that mpn_sec_tabselect
             places at tp + 2n.  */
     354  #define INNERLOOP							\
     355    while (enb != 0)							\
     356      {									\
     357        expbits = getbits (ep, enb, windowsize);				\
     358        this_windowsize = windowsize;					\
     359        if (enb < windowsize)						\
     360  	{								\
     361  	  this_windowsize -= windowsize - enb;				\
     362  	  enb = 0;							\
     363  	}								\
     364        else								\
     365  	enb -= windowsize;						\
     366  									\
     367        do								\
     368  	{								\
     369  	  mpn_local_sqr (tp, rp, n);					\
     370  	  MPN_REDUCE (rp, tp, mp, n, mip);				\
     371  	  this_windowsize--;						\
     372  	}								\
     373        while (this_windowsize != 0);					\
     374  									\
     375        mpn_sec_tabselect (tp + 2*n, pp, n, 1 << windowsize, expbits);	\
     376        mpn_mul_basecase (tp, rp, n, tp + 2*n, n);			\
     377  									\
     378        MPN_REDUCE (rp, tp, mp, n, mip);					\
     379      }
     380  
     381    if (BELOW_THRESHOLD (n, REDC_1_TO_REDC_2_THRESHOLD))
     382      {
     383  #undef MPN_REDUCE
     384  #define MPN_REDUCE(rp,tp,mp,n,mip)	MPN_REDC_1_SEC (rp, tp, mp, n, mip[0])
     385        INNERLOOP;
     386      }
     387    else
     388      {
     389  #undef MPN_REDUCE
     390  #define MPN_REDUCE(rp,tp,mp,n,mip)	MPN_REDC_2_SEC (rp, tp, mp, n, mip)
     391        INNERLOOP;
     392      }
     393  
            /* Leave REDC form: run one more reduction on {rp,n} extended with n
               zero limbs.  */
     394    MPN_COPY (tp, rp, n);
     395    MPN_ZERO (tp + n, n);
     396  
     397    if (BELOW_THRESHOLD (n, REDC_1_TO_REDC_2_THRESHOLD))
     398      MPN_REDC_1_SEC (rp, tp, mp, n, mip[0]);
     399    else
     400      MPN_REDC_2_SEC (rp, tp, mp, n, mip);
     401  
            /* Final correction: cnd is the borrow out of rp - mp (the difference
               itself, written to tp, is discarded), so mp is subtracted from rp
               exactly when that subtraction did not borrow.  */
     402    cnd = mpn_sub_n (tp, rp, mp, n);	/* we need just retval */
     403    mpn_cnd_sub_n (!cnd, rp, rp, mp, n);
     404  }
     405  
     406  mp_size_t
     407  mpn_sec_powm_itch (mp_size_t bn, mp_bitcnt_t enb, mp_size_t n)
     408  {
     409    /* Return the number of scratch limbs mpn_sec_powm needs at tp for a
     410       bn-limb base, an enb-bit exponent and an n-limb modulus.  The peak is
     411       either the 2nd redcify call (reducing B) or, more typically, the main
     412       loop.  FIXME: no more _local/_basecase difference.  The loop needs 3n
     413       or 4n beyond the table depending on whether mpn_local_sqr or a native
     414       mpn_sqr_basecase is used; 4n is assumed always for now.  */
     415  
     416    int wbits = win_size (enb);
     417  
     418    /* At the 2nd redcify call: pp[0] and pp[1] (2n), then redcify's own
     419       bn + n limbs, then mpn_sec_div_r's need when called from redcify.  */
     420    mp_size_t table_head = 2 * n;
     421    mp_size_t dividend = bn + n;
     422    mp_size_t div_scratch = dividend + 2 * n + 2;
     423    mp_size_t conv_need = table_head + dividend + div_scratch;
     424  
     425    /* In the exponentiation loop: the power table, n * 2^wbits limbs, plus
     426       4n limbs for the scratch needs of squaring/multiplication.  */
     427    mp_size_t loop_need = (n << wbits) + (4 * n);
     428  
     429    return MAX (loop_need, conv_need);
     430  }