(root)/
gmp-6.3.0/
mpn/
generic/
toom32_mul.c
       1  /* mpn_toom32_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 1.5
       2     times as large as bn.  Or more accurately, bn < an < 3bn.
       3  
       4     Contributed to the GNU project by Torbjorn Granlund.
       5     Improvements by Marco Bodrato and Niels Möller.
       6  
       7     The idea of applying Toom to unbalanced multiplication is due to Marco
       8     Bodrato and Alberto Zanoni.
       9  
      10     THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
      11     SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
      12     GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
      13  
      14  Copyright 2006-2010, 2020, 2021 Free Software Foundation, Inc.
      15  
      16  This file is part of the GNU MP Library.
      17  
      18  The GNU MP Library is free software; you can redistribute it and/or modify
      19  it under the terms of either:
      20  
      21    * the GNU Lesser General Public License as published by the Free
      22      Software Foundation; either version 3 of the License, or (at your
      23      option) any later version.
      24  
      25  or
      26  
      27    * the GNU General Public License as published by the Free Software
      28      Foundation; either version 2 of the License, or (at your option) any
      29      later version.
      30  
      31  or both in parallel, as here.
      32  
      33  The GNU MP Library is distributed in the hope that it will be useful, but
      34  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
      35  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
      36  for more details.
      37  
      38  You should have received copies of the GNU General Public License and the
      39  GNU Lesser General Public License along with the GNU MP Library.  If not,
      40  see https://www.gnu.org/licenses/.  */
      41  
      42  
      43  #include "gmp-impl.h"
      44  
      45  /* Evaluate in: -1, 0, +1, +inf
      46  
      47    <-s-><--n--><--n-->
      48     ___ ______ ______
      49    |a2_|___a1_|___a0_|
      50  	|_b1_|___b0_|
      51  	<-t--><--n-->
      52  
      53    v0  =  a0         * b0      #   A(0)*B(0)
      54    v1  = (a0+ a1+ a2)*(b0+ b1) #   A(1)*B(1)      ah  <= 2  bh <= 1
      55    vm1 = (a0- a1+ a2)*(b0- b1) #  A(-1)*B(-1)    |ah| <= 1  bh = 0
      56    vinf=          a2 *     b1  # A(inf)*B(inf)
      57  */
      58  
      59  #define TOOM32_MUL_N_REC(p, a, b, n, ws)				\
      60    do {									\
      61      mpn_mul_n (p, a, b, n);						\
      62    } while (0)
      63  
/* Multiply {ap,an} by {bp,bn}, writing the an + bn limb product to pp.

   The first operand is split into three pieces a0, a1 (n limbs each) and a2
   (s limbs); the second into two pieces b0 (n limbs) and b1 (t limbs).  The
   polynomials A(x) = a0 + a1 x + a2 x^2 and B(x) = b0 + b1 x are evaluated
   at 0, +1, -1 and infinity, the four point values are multiplied, and the
   product coefficients x0..x3 are recovered by interpolation.

   pp       Product area, an + bn limbs.  It is also used as temporary
	    storage for the evaluated operands (ap1, bp1, am1, bm1) and for
	    vm1 before the final result is assembled.
   ap, an   First (larger) operand and its size in limbs.
   bp, bn   Second operand and its size in limbs.
   scratch  Temporary area; 2*n + 1 limbs are used here (for v1).  The area
	    beyond that (scratch_out) is passed to TOOM32_MUL_N_REC, which
	    currently ignores it.

   Size requirements (checked by ASSERT): bn + 2 <= an and an + 6 <= 3*bn.  */
void
mpn_toom32_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;			/* nonzero when A(-1)*B(-1) is negative */
  mp_limb_t cy;
  mp_limb_signed_t hi;		/* high limb of am1, later a signed carry */
  mp_limb_t ap1_hi, bp1_hi;	/* high limbs of A(1) and B(1) */

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2 * n)
#define b0  bp
#define b1  (bp + n)

  /* Required, to ensure that s + t >= n. */
  ASSERT (bn + 2 <= an && an + 6 <= 3*bn);

  /* Piece size: ceil(an/3) when a is at least 1.5 times as long as b,
     otherwise ceil(bn/2). */
  n = 2 * an >= 3 * bn ? (an + 2) / (size_t) 3 : (bn + 1) >> 1;

  /* Sizes of the (possibly shorter) top pieces a2 and b1. */
  s = an - 2 * n;
  t = bn - n;

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);
  ASSERT (s + t >= n);

  /* Product area of size an + bn = 3*n + s + t >= 4*n + 2. */
#define ap1 (pp)		/* n, most significant limb in ap1_hi */
#define bp1 (pp + n)		/* n, most significant bit in bp1_hi */
#define am1 (pp + 2*n)		/* n, most significant bit in hi */
#define bm1 (pp + 3*n)		/* n */
#define v1 (scratch)		/* 2n + 1 */
#define vm1 (pp)		/* 2n + 1 */
#define scratch_out (scratch + 2*n + 1) /* Currently unused. */

  /* Scratch need: 2*n + 1 + scratch for the recursive multiplications. */

  /* FIXME: Keep v1[2*n] and vm1[2*n] in scalar variables? */

  /* Compute ap1 = a0 + a1 + a2, am1 = a0 - a1 + a2 */
  /* Start with a0 + a2; its carry-out limb goes to ap1_hi. */
  ap1_hi = mpn_add (ap1, a0, n, a2, s);
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      /* a0 + a2 < a1, so A(-1) is negative: store |A(-1)| = a1 - (a0 + a2)
	 and remember the sign.  Only the bits above bit 0 of the combined
	 return value are kept, as the high limb of the sum. */
      ap1_hi = mpn_add_n_sub_n (ap1, am1, a1, ap1, n) >> 1;
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      /* A(-1) >= 0.  Bit 0 of cy is applied as the borrow of the difference
	 and the remaining bits as the carry of the sum. */
      cy = mpn_add_n_sub_n (ap1, am1, ap1, a1, n);
      hi = ap1_hi - (cy & 1);
      ap1_hi += (cy >> 1);
      vm1_neg = 0;
    }
#else
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      /* a0 + a2 < a1, so A(-1) is negative: store |A(-1)| = a1 - (a0 + a2)
	 and remember the sign. */
      ASSERT_NOCARRY (mpn_sub_n (am1, a1, ap1, n));
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      /* A(-1) >= 0; its high limb is the carry of a0 + a2 minus the borrow
	 from subtracting a1. */
      hi = ap1_hi - mpn_sub_n (am1, ap1, a1, n);
      vm1_neg = 0;
    }
  /* Finish A(1) by adding a1 to a0 + a2 (am1 was computed first, since it
     needs the un-updated a0 + a2 still held in ap1). */
  ap1_hi += mpn_add_n (ap1, ap1, a1, n);
#endif

  /* Compute bp1 = b0 + b1 and bm1 = b0 - b1. */
  if (t == n)
    {
      /* b0 and b1 have the same size: a plain n-limb compare decides the
	 sign of b0 - b1, and |b0 - b1| is stored in bm1. */
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b1, b0, n);
	  vm1_neg ^= 1;
	}
      else
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b0, b1, n);
	}
      /* Keep only the carry of the sum; the difference was arranged to be
	 non-negative. */
      bp1_hi = cy >> 1;
#else
      bp1_hi = mpn_add_n (bp1, b0, b1, n);

      if (mpn_cmp (b0, b1, n) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, n));
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b0, b1, n));
	}
#endif
    }
  else
    {
      /* FIXME: Should still use mpn_add_n_sub_n for the main part. */
      bp1_hi = mpn_add (bp1, b0, n, b1, t);

      /* b1 has only t < n limbs, so b0 < b1 exactly when the upper n - t
	 limbs of b0 are zero and its low t limbs compare below b1. */
      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, t));
	  MPN_ZERO (bm1 + t, n - t);
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub (bm1, b0, n, b1, t));
	}
    }

  /* v1 = A(1) * B(1).  The n x n product covers only the low n limbs of each
     factor; the contributions of the high limbs ap1_hi (0, 1 or 2) and
     bp1_hi (0 or 1) are added at limb offset n afterwards. */
  TOOM32_MUL_N_REC (v1, ap1, bp1, n, scratch_out);
  if (ap1_hi == 1)
    {
      cy = mpn_add_n (v1 + n, v1 + n, bp1, n);
    }
  else if (ap1_hi > 1) /* ap1_hi == 2 */
    {
#if HAVE_NATIVE_mpn_addlsh1_n_ip1
      cy = mpn_addlsh1_n_ip1 (v1 + n, bp1, n);
#else
      cy = mpn_addmul_1 (v1 + n, bp1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  /* When bp1_hi is set, add ap1 at offset n; the extra ap1_hi term is the
     product of the two high limbs, which lands in the top limb. */
  if (bp1_hi != 0)
    cy += ap1_hi + mpn_add_n (v1 + n, v1 + n, ap1, n);
  v1[2 * n] = cy;

  /* vm1 = |A(-1)| * |B(-1)|, written over ap1/bp1 at pp, which are no longer
     needed.  Only am1 can have a high bit (hi); when it is set, bm1 is added
     at limb offset n. */
  TOOM32_MUL_N_REC (vm1, am1, bm1, n, scratch_out);
  if (hi)
    hi = mpn_add_n (vm1+n, vm1+n, bm1, n);

  vm1[2*n] = hi;

  /* v1 <-- (v1 + vm1) / 2 = x0 + x2 */
  /* vm1 holds an absolute value, so it is subtracted when vm1_neg is set. */
  if (vm1_neg)
    {
#if HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (v1, v1, vm1, 2*n+1);
#else
      mpn_sub_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }
  else
    {
#if HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (v1, v1, vm1, 2*n+1);
#else
      mpn_add_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }

  /* We get x1 + x3 = (x0 + x2) - (x0 - x1 + x2 - x3), and hence

     y = x1 + x3 + (x0 + x2) * B
       = (x0 + x2) * B + (x0 + x2) - vm1.

     y is 3*n + 1 limbs, y = y0 + y1 B + y2 B^2. We store them as
     follows: y0 at scratch, y1 at pp + 2*n, and y2 at scratch + n
     (already in place, except for carry propagation).

     We thus add

   B^3  B^2   B    1
    |    |    |    |
   +-----+----+
 + |  x0 + x2 |
   +----+-----+----+
 +      |  x0 + x2 |
	+----------+
 -      |  vm1     |
 --+----++----+----+-
   | y2  | y1 | y0 |
   +-----+----+----+

  Since we store y0 at the same location as the low half of x0 + x2, we
  need to do the middle sum first. */

  /* Save the top limb of vm1 before pp + 2*n (where vm1[2*n] lives) is
     overwritten with y1. */
  hi = vm1[2*n];
  /* Middle column: low half plus high half of x0 + x2 gives y1; the carry
     and the top limb of x0 + x2 are propagated into y2. */
  cy = mpn_add_n (pp + 2*n, v1, v1 + n, n);
  MPN_INCR_U (v1 + n, n + 1, cy + v1[2*n]);

  /* FIXME: Can we get rid of this second vm1_neg conditional by
     swapping the location of +1 and -1 values? */
  /* Apply the "- vm1" term across y0, y1 and y2; since vm1 holds an absolute
     value, it is added when vm1_neg is set and subtracted otherwise. */
  if (vm1_neg)
    {
      cy = mpn_add_n (v1, v1, vm1, n);
      hi += mpn_add_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_INCR_U (v1 + n, n+1, hi);
    }
  else
    {
      cy = mpn_sub_n (v1, v1, vm1, n);
      hi += mpn_sub_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_DECR_U (v1 + n, n+1, hi);
    }

  /* v0 = x0 = a0 * b0, 2n limbs at pp (vm1 is no longer needed). */
  TOOM32_MUL_N_REC (pp, a0, b0, n, scratch_out);
  /* vinf, s+t limbs.  Use mpn_mul for now, to handle unbalanced operands */
  /* The longer operand is passed first. */
  if (s > t)  mpn_mul (pp+3*n, a2, s, b1, t);
  else        mpn_mul (pp+3*n, b1, t, a2, s);

  /* Remaining interpolation.

     y * B + x0 + x3 B^3 - x0 B^2 - x3 B
     = (x1 + x3) B + (x0 + x2) B^2 + x0 + x3 B^3 - x0 B^2 - x3 B
     = y0 B + y1 B^2 + y2 B^3 + Lx0 + H x0 B
       + L x3 B^3 + H x3 B^4 - Lx0 B^2 - H x0 B^3 - L x3 B - H x3 B^2
     = L x0 + (y0 + H x0 - L x3) B + (y1 - L x0 - H x3) B^2
       + (y2 - (H x0 - L x3)) B^3 + H x3 B^4

	  B^4       B^3       B^2        B         1
 |         |         |         |         |         |
   +-------+                   +---------+---------+
   |  Hx3  |                   | Hx0-Lx3 |    Lx0  |
   +------+----------+---------+---------+---------+
	  |    y2    |  y1     |   y0    |
	  ++---------+---------+---------+
	  -| Hx0-Lx3 | - Lx0   |
	   +---------+---------+
		      | - Hx3  |
		      +--------+

    We must take into account the carry from Hx0 - Lx3.
  */

  /* pp + n <-- Hx0 - Lx3, borrow in cy.  hi collects the signed carry for
     the top part, starting from the top limb of y2 (scratch[2*n]). */
  cy = mpn_sub_n (pp + n, pp + n, pp+3*n, n);
  hi = scratch[2*n] + cy;

  /* pp + 2*n <-- y1 - Lx0, then pp + 3*n <-- y2 - (Hx0 - Lx3); this
     overwrites Lx3, which has already been consumed above. */
  cy = mpn_sub_nc (pp + 2*n, pp + 2*n, pp, n, cy);
  hi -= mpn_sub_nc (pp + 3*n, scratch + n, pp + n, n, cy);

  /* Add y0 at limb offset n, propagating the carry through pp + 4*n - 1. */
  hi += mpn_add (pp + n, pp + n, 3*n, scratch, n);

  /* FIXME: Is support for s + t == n needed? */
  if (LIKELY (s + t > n))
    {
      /* Subtract Hx3 (s + t - n limbs at pp + 4*n) at limb offset 2*n, then
	 fold the accumulated carry into Hx3 itself. */
      hi -= mpn_sub (pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

      ASSERT (hi >= 0); /* contribution of the middle terms >= 0 */
      MPN_INCR_U (pp + 4*n, s+t-n, hi);
    }
  else
    ASSERT (hi == 0);
}