How many 64-bit multiplications are needed to calculate the low 128-bits of a 64-bit by 128-bit product?

旧时难觅i 2020-12-17 15:32

Consider that you want to calculate the low 128-bits of the result of multiplying a 64-bit and 128-bit unsigned number, and that the largest multiplication you have availabl

2 answers
  •  心在旅途
    2020-12-17 16:02

    Of course, without Karatsuba, 5 multiplies.
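
To make the count of five concrete, here is a minimal sketch, not taken from the answer itself. It assumes the only multiply available is the C `uint64_t * uint64_t` that returns the low 64 bits; the names `u128_t`, `mul64_full` and `mul64x128_lo` are hypothetical. Four multiplies build the full product of `a` with the low word of the 128-bit operand, and a fifth supplies the low 64 bits of `a` times the high word.

```c
#include <stdint.h>

/* Hypothetical type for this sketch: a 128-bit value as two 64-bit words. */
typedef struct { uint64_t lo, hi; } u128_t;

/* Full 64 x 64 -> 128 product from four 32 x 32 -> 64 multiplies. */
static u128_t mul64_full(uint64_t a, uint64_t b)
{
    uint64_t a0 = (uint32_t)a, a1 = a >> 32;
    uint64_t b0 = (uint32_t)b, b1 = b >> 32;

    uint64_t p00 = a0 * b0;   /* multiply 1 */
    uint64_t p01 = a0 * b1;   /* multiply 2 */
    uint64_t p10 = a1 * b0;   /* multiply 3 */
    uint64_t p11 = a1 * b1;   /* multiply 4 */

    /* Everything that lands in bits 32..63; the sum is below 3 * 2^32,
       so it cannot overflow and its top bits are the carry into hi. */
    uint64_t mid = (p00 >> 32) + (uint32_t)p01 + (uint32_t)p10;

    u128_t r;
    r.lo = (mid << 32) | (uint32_t)p00;
    r.hi = p11 + (p01 >> 32) + (p10 >> 32) + (mid >> 32);
    return r;
}

/* Low 128 bits of a * (b_hi:b_lo). */
static u128_t mul64x128_lo(uint64_t a, uint64_t b_lo, uint64_t b_hi)
{
    u128_t r = mul64_full(a, b_lo);
    r.hi += a * b_hi;         /* multiply 5: only its low 64 bits matter */
    return r;
}
```

For example, `mul64x128_lo(3, 5, 7)` yields `lo == 15` and `hi == 21`. The high 64 bits of `a * b_hi` would fall above bit 127, which is why a truncating multiply is enough for the fifth product.
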

    Karatsuba is wonderful, but these days a 64 x 64 multiply can be over in 3 clocks and a new one can be scheduled every clock. So the overhead of dealing with the signs and whatnot can be significantly greater than the saving of one multiply.

    A straightforward 64 x 64 multiply, split into 32-bit halves, needs four partial products:

        r0 =       a0*b0
        r1 =    a0*b1
        r2 =    a1*b0
        r3 = a1*b1
    
      where we need to add: r0 = r0 + (r1 << 32) + (r2 << 32)
                  and then: r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry
    
      where carry is the carry out of the additions to r0, and the result is r3:r0.
    
    #include <stdint.h>
    
    typedef struct { uint64_t w0, w1 ; } uint64x2_t ;
    
    uint64x2_t
    mulu64x2(uint64_t x, uint64_t m)
    {
      uint64x2_t r ;
      uint64_t r1, r2, rx, ry ;
      uint32_t x1, x0 ;
      uint32_t m1, m0 ;
    
      x1    = (uint32_t)(x >> 32) ;
      x0    = (uint32_t)x ;
      m1    = (uint32_t)(m >> 32) ;
      m0    = (uint32_t)m ;
    
      r1    = (uint64_t)x1 * m0 ;
      r2    = (uint64_t)x0 * m1 ;
      r.w0  = (uint64_t)x0 * m0 ;
      r.w1  = (uint64_t)x1 * m1 ;
    
      rx    = (uint32_t)r1 ;
      rx    = rx + (uint32_t)r2 ;    // add the ls halves, collecting carry
      ry    = r.w0 >> 32 ;           // pick up ms of r0
      r.w0 += (rx << 32) ;           // complete r0
      rx   += ry ;                   // complete addition, rx >> 32 == carry !
    
      r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;
    
      return r ;
    }
    

    For Karatsuba, the suggested:

    z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
    

    is trickier than it looks... for a start, if z1 is 64 bits, then you somehow need to collect the carry which this addition can generate... and that is complicated by the signedness issues.
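
As a reminder of why that formula works and where the extra bit comes from, here is the identity written out. The symbols follow the answer's own names, with a and b each split into 32-bit halves.

```latex
\[
\begin{aligned}
a &= a_1 2^{32} + a_0, \qquad b = b_1 2^{32} + b_0, \\
z_0 &= a_0 b_0, \qquad z_2 = a_1 b_1, \\
(a_0 - a_1)(b_1 - b_0) &= a_0 b_1 + a_1 b_0 - z_0 - z_2, \\
z_1 &= (a_0 - a_1)(b_1 - b_0) + z_0 + z_2 = a_0 b_1 + a_1 b_0, \\
ab &= z_2 2^{64} + z_1 2^{32} + z_0 .
\end{aligned}
\]
```

Since \(0 \le z_1 \le 2(2^{32}-1)^2 < 2^{65}\), the middle term can need 65 bits, and the signed product \((a_0 - a_1)(b_1 - b_0)\) has a magnitude below \(2^{64}\) plus a sign. Neither fits a plain 64-bit unsigned word without extra bookkeeping.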

        z0 =       a0*b0
        z1 =    ax*bx        -- ax = (a1 - a0), bx = (b0 - b1)
        z2 = a1*b1
    
      where we need to add: r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
                  and then: r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry
    
      where carry is the carry out of the additions that create r0, and the result is r1:r0,
    
      and where the signedness of ax, bx and z1 must be taken into account.
    
    typedef unsigned int uint ;   // 'uint' is not a standard C type
    
    uint64x2_t
    mulu64x2_karatsuba(uint64_t a, uint64_t b)
    {
      uint64_t a0, a1, b0, b1 ;
      uint64_t ax, bx, zx, zy ;
      uint     as, bs, xs ;
      uint64_t z0, z2 ;
      uint64x2_t r ;
    
      a0 = (uint32_t)a ; a1 = a >> 32 ;
      b0 = (uint32_t)b ; b1 = b >> 32 ;
    
      z0 = a0 * b0 ;
      z2 = a1 * b1 ;
    
      ax = (uint64_t)(a1 - a0) ;
      bx = (uint64_t)(b0 - b1) ;
    
      as = (uint)(ax > a1) ;                // sign of magic middle, a
      bs = (uint)(bx > b0) ;                // sign of magic middle, b
      xs = (uint)(as ^ bs) ;                // sign of magic middle, x = a * b
    
      ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;  // abs magic middle a
      bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;  // abs magic middle b
    
      zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
      xs = xs & (uint)(zx != 0) ;           // discard sign if z1 == 0 !
    
      zy = (uint32_t)zx ;                   // start ls half of z1
      zy = zy + (uint32_t)z0 + (uint32_t)z2 ;
    
      r.w0 = z0 + (zy << 32) ;              // complete ls word of result.
      zy   = zy + (z0 >> 32) ;              // complete carry
    
      zx   = (zx >> 32) - ((uint64_t)xs << 32) ;   // start ms half of z1
      r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;
    
      return r ;
    }
    

    I did some very simple timings (using times(), running on Ryzen 7 1800X):

    • using gcc __int128................... ~780 'units'
    • using mulu64x2()..................... ~895
    • using mulu64x2_karatsuba()... ~1,095

    ...so, yes, you can save a multiply by using Karatsuba, but whether it's worth doing rather depends.
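
The `__int128` baseline from the timings is not shown in the answer. A minimal sketch of what it presumably looks like follows, assuming GCC or Clang on a 64-bit target, where `unsigned __int128` is available as a compiler extension. The function name `mulu64x2_int128` is hypothetical, and the `uint64x2_t` typedef is repeated only so the block compiles on its own.

```c
#include <stdint.h>

/* Same result layout as in the answer: w0 = low word, w1 = high word. */
typedef struct { uint64_t w0, w1 ; } uint64x2_t ;

/* Assumption: compiler provides unsigned __int128 (GCC/Clang extension). */
static uint64x2_t mulu64x2_int128(uint64_t x, uint64_t m)
{
    unsigned __int128 p = (unsigned __int128)x * m ;  /* full 128-bit product */
    uint64x2_t r ;

    r.w0 = (uint64_t)p ;           /* low 64 bits  */
    r.w1 = (uint64_t)(p >> 64) ;   /* high 64 bits */

    return r ;
}
```

On x86-64 a compiler will typically turn this into a single widening multiply instruction, which would be consistent with it being the fastest of the three variants timed above.
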
