(a * b) / c MulDiv and dealing with overflow from intermediate multiplication

旧街凉风 submitted on 2019-12-01 03:48:38

I've been tinkering with an approach that (1) multiplies a and b with the schoolbook algorithm on 21-bit limbs, then (2) does long division by c, using an unusual representation of the residual a*b - c*q: a double stores the high-order bits and a long stores the low-order bits. I don't know whether it can be made competitive with standard long division, but for your enjoyment:

public class MulDiv {
  public static void main(String[] args) {
    java.util.Random r = new java.util.Random();
    for (long i = 0; true; i++) {
      if (i % 1000000 == 0) {
        System.err.println(i);
      }
      long a = r.nextLong() >> (r.nextInt(8) * 8);
      long b = r.nextLong() >> (r.nextInt(8) * 8);
      long c = r.nextLong() >> (r.nextInt(8) * 8);
      if (c == 0) {
        continue;
      }
      long x = mulDiv(a, b, c);
      java.math.BigInteger aa = java.math.BigInteger.valueOf(a);
      java.math.BigInteger bb = java.math.BigInteger.valueOf(b);
      java.math.BigInteger cc = java.math.BigInteger.valueOf(c);
      java.math.BigInteger xx = aa.multiply(bb).divide(cc);
      if (java.math.BigInteger.valueOf(xx.longValue()).equals(xx) && x != xx.longValue()) {
        System.out.printf("a=%d b=%d c=%d: %d != %s\n", a, b, c, x, xx);
      }
    }
  }

  // Returns truncate(a b/c), subject to the precondition that the result is
  // defined and can be represented as a long.
  private static long mulDiv(long a, long b, long c) {
    // Decompose a.
    long a2 = a >> 42;
    long a10 = a - (a2 << 42);
    long a1 = a10 >> 21;
    long a0 = a10 - (a1 << 21);
    assert a == (((a2 << 21) + a1) << 21) + a0;
    // Decompose b.
    long b2 = b >> 42;
    long b10 = b - (b2 << 42);
    long b1 = b10 >> 21;
    long b0 = b10 - (b1 << 21);
    assert b == (((b2 << 21) + b1) << 21) + b0;
    // Compute a b.
    long ab4 = a2 * b2;
    long ab3 = a2 * b1 + a1 * b2;
    long ab2 = a2 * b0 + a1 * b1 + a0 * b2;
    long ab1 = a1 * b0 + a0 * b1;
    long ab0 = a0 * b0;
    // Compute a b/c.
    DivBy d = new DivBy(c);
    d.shift21Add(ab4);
    d.shift21Add(ab3);
    d.shift21Add(ab2);
    d.shift21Add(ab1);
    d.shift21Add(ab0);
    return d.getQuotient();
  }
}

public strictfp class DivBy {
  // Initializes n <- 0.
  public DivBy(long d) {
    di = d;
    df = (double) d;
    oneOverD = 1.0 / df;
  }

  // Updates n <- 2^21 n + i. Assumes |i| <= 3 (2^42).
  public void shift21Add(long i) {
    // Update the quotient and remainder.
    q <<= 21;
    ri = (ri << 21) + i;
    rf = rf * (double) (1 << 21) + (double) i;
    reduce();
  }

  // Returns truncate(n/d).
  public long getQuotient() {
    while (rf != (double) ri) {
      reduce();
    }
    // Round toward zero.
    if (q > 0) {
      if ((di > 0 && ri < 0) || (di < 0 && ri > 0)) {
        return q - 1;
      }
    } else if (q < 0) {
      if ((di > 0 && ri > 0) || (di < 0 && ri < 0)) {
        return q + 1;
      }
    }
    return q;
  }

  private void reduce() {
    // x is approximately r/d.
    long x = Math.round(rf * oneOverD);
    q += x;
    ri -= di * x;
    rf = repairLowOrderBits(rf - df * (double) x, ri);
  }

  private static double repairLowOrderBits(double f, long i) {
    int e = Math.getExponent(f);
    if (e < 64) {
      return (double) i;
    }
    long rawBits = Double.doubleToRawLongBits(f);
    long lowOrderBits = (rawBits >> 63) ^ (rawBits << (e - 52));
    return f + (double) (i - lowOrderBits);
  }

  private final long di;
  private final double df;
  private final double oneOverD;
  private long q = 0;
  private long ri = 0;
  private double rf = 0;
}
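
The 21-bit limb size is what keeps every partial product inside a long: each limb product has magnitude at most 2^42, and at most three of them are summed per column, which matches the |i| <= 3 (2^42) precondition of shift21Add. Here is a minimal sketch of that decomposition on its own. The class LimbProduct and its method names are hypothetical helpers, not part of the answer above; BigInteger is used only to recombine the five column sums so they can be compared with the exact product.

```java
import java.math.BigInteger;

// Hypothetical helper, for illustration only.
final class LimbProduct {
    // Splits v into limbs {v0, v1, v2} with v == (v2 << 42) + (v1 << 21) + v0.
    // v0 and v1 lie in [0, 2^21); v2 carries the sign.
    static long[] limbs(long v) {
        long v2 = v >> 42;
        long v10 = v - (v2 << 42);
        long v1 = v10 >> 21;
        long v0 = v10 - (v1 << 21);
        return new long[] {v0, v1, v2};
    }

    // Schoolbook multiplication: p[k] is the sum of all x[i] * y[j] with i + j == k.
    static long[] partialProducts(long a, long b) {
        long[] x = limbs(a);
        long[] y = limbs(b);
        long[] p = new long[5];
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                p[i + j] += x[i] * y[j];
            }
        }
        return p;
    }

    // Recombines the column sums as sum over k of p[k] * 2^(21 k).
    static BigInteger recombine(long[] p) {
        BigInteger sum = BigInteger.ZERO;
        for (int k = 4; k >= 0; k--) {
            sum = sum.shiftLeft(21).add(BigInteger.valueOf(p[k]));
        }
        return sum;
    }
}
```

For any a and b, recombine(partialProducts(a, b)) should equal BigInteger.valueOf(a).multiply(BigInteger.valueOf(b)).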

You can use the greatest common divisor (gcd) to help.

a * b / c = (a / gcd(a,c)) * (b / (c / gcd(a,c)))

Edit: The OP asked me to explain the above equation. Basically, we have:

a = (a / gcd(a,c)) * gcd(a,c)
c = (c / gcd(a,c)) * gcd(a,c)

Let's say x=gcd(a,c) for brevity, and rewrite this.

a*b/c = (a/x) * x * b 
        --------------
        (c/x) * x

Next, we cancel

a*b/c = (a/x) * b 
        ----------
        (c/x) 

You can take this a step further. Let y = gcd(b, c/x)

a*b/c = (a/x) * (b/y) * y 
        ------------------
        ((c/x)/y) * y 

a*b/c = (a/x) * (b/y) 
        ------------
           (c/(xy))

Here's code to get the gcd.

static long gcd(long a, long b) 
{ 
  if (b == 0) 
    return a; 
  return gcd(b, a % b);  
} 
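
Putting the two reduction steps together might look like the sketch below. The class name GcdMulDiv is made up for illustration, and the code assumes no argument is Long.MIN_VALUE (whose absolute value does not fit in a long). Note that the reduction only removes common factors: if c does not cancel completely, the reduced product can still overflow, so the sketch uses Math.multiplyExact to fail loudly with an ArithmeticException instead of returning a wrong value.

```java
// Hypothetical helper, for illustration only.
final class GcdMulDiv {
    // Iterative Euclid; returns a non-negative result for inputs above Long.MIN_VALUE.
    static long gcd(long a, long b) {
        while (b != 0) {
            long t = a % b;
            a = b;
            b = t;
        }
        return Math.abs(a);
    }

    // Returns a * b / c truncated toward zero, cancelling common factors first.
    static long mulDiv(long a, long b, long c) {
        if (c == 0) {
            throw new ArithmeticException("division by zero");
        }
        long x = gcd(a, c);   // x >= 1 because c != 0
        a /= x;
        c /= x;
        long y = gcd(b, c);   // y >= 1 because c is still non-zero
        b /= y;
        c /= y;
        // The remaining c shares no factor with a or b; the product may still overflow.
        return Math.multiplyExact(a, b) / c;
    }
}
```

For example, GcdMulDiv.mulDiv(6_000_000_000L, 9_000_000_000L, 3_000_000_000L) reduces to 2 * 9_000_000_000 / 1, even though the unreduced product does not fit in a long.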

David Eisenstat got me thinking some more.
I want simple cases to be fast: let double take care of that. Newton-Raphson may be a better choice for the rest.

 /** Multiplies both <code>factor</code>s
  *  and divides by <code>divisor</code>.
  * @return <code>Long.MIN_VALUE</code> if result out of range,<br/>
  *     else <code>factorA * factor1 / divisor</code> */
    public static long
    mulDiv(long factorA, long factor1, long divisor) {
        final double dd = divisor,
            product = (double)factorA * factor1,
            a1_d = product / dd;
        if (a1_d < -TOO_LARGE || TOO_LARGE < a1_d)
            return tooLarge();
        if (-ONE_ < a1_d && a1_d < ONE_)
            return 0;
        if (-EXACT < product && product < EXACT)
            return (long) a1_d;
        long pLo = factorA * factor1, //diff,
            pHi = high64(factorA, factor1);
        if (a1_d < -LONG_MAX_ || LONG_MAX_ < a1_d) {
            long maxdHi = divisor >> 1;
            if (maxdHi < pHi
                || maxdHi == pHi
                   && Long.compareUnsigned((divisor << Long.SIZE-1),
                                           pLo) <= 0)
                return tooLarge();
        }
        final double high_dd = TWO_POWER64/dd;
        long quotient = (long) a1_d,
            loPP = quotient * divisor,
            hiPP = high64(quotient, divisor);
        long remHi = pHi - hiPP, // xxx overflow/carry
            remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        double fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        long //fHi = (long)fudge/TWO_POWER64,
            fLo = (long) Math.floor(fudge); //*round
        quotient += fLo;
        loPP = quotient * divisor;
        hiPP = high64(quotient, divisor);
        remHi = pHi - hiPP; // should be 0?!
        remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        if (0 == remHi && 0 <= remLo && remLo < divisor)
            return quotient;

        fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        fLo = (long) Math.floor(fudge);
        return quotient + fLo;
    }

 /** max <code>double</code> trusted to represent
  *  a value in the range of <code>long</code> */
    static final double
        LONG_MAX_ = Double.valueOf(Long.MAX_VALUE - 0xFFF);
 /** max <code>double</code> trusted to represent a value below 1 */
    static final double
        ONE_ = Double.longBitsToDouble(
                    Double.doubleToRawLongBits(1) - 4);
 /** max <code>double</code> trusted to represent a value exactly */
    static final double
        EXACT = Long.MAX_VALUE >> 12;
    static final double
        TWO_POWER64 = Double.valueOf(1L<<32)*Double.valueOf(1L<<32);

    static long tooLarge() {
//      throw new RuntimeException("result too large for long");
        return Long.MIN_VALUE;
    }
    static final long   ONES_32 = ~(~0L << 32);

    static long high64(long factorA, long factor1) {
        long loA = factorA & ONES_32,
            hiA = factorA >>> 32,
            lo1 = factor1 & ONES_32,
            hi1 = factor1 >>> 32;
        return ((loA * lo1 >>> 32)
                +loA * hi1 + hiA * lo1 >>> 32)
               + hiA * hi1;
    }

(I rearranged this code somewhat outside the IDE to put mulDiv() on top. Being lazy, I use a wrapper for sign handling; I might try to do it properly before hell freezes over.
For timing, a model of the input is badly needed:
how about one in which each possible result is equally likely?)
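
For non-negative operands (which is what the sign-handling wrapper would pass in), high64 can be cross-checked against Math.multiplyHigh, which the JDK has provided since Java 9. The sketch below assumes Java 9 or later; the class name High64Check and the method agrees are hypothetical, and the high64 body is copied from the code above.

```java
// Hypothetical helper, for illustration only; requires Java 9+ for Math.multiplyHigh.
final class High64Check {
    static final long ONES_32 = ~(~0L << 32);

    // Copied from the answer: upper 64 bits of the 128-bit product,
    // treating both operands as unsigned.
    static long high64(long factorA, long factor1) {
        long loA = factorA & ONES_32,
            hiA = factorA >>> 32,
            lo1 = factor1 & ONES_32,
            hi1 = factor1 >>> 32;
        return ((loA * lo1 >>> 32)
                + loA * hi1 + hiA * lo1 >>> 32)
               + hiA * hi1;
    }

    // For non-negative a and b, signed and unsigned high halves coincide.
    static boolean agrees(long a, long b) {
        return high64(a, b) == Math.multiplyHigh(a, b);
    }
}
```

If the JDK method is available, it could stand in for high64 whenever the operands are known to be non-negative.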

Divide a/c and b/c into whole and fractional (remainder) parts, then you have:

a*b/c 
= c * a/c * b/c 
= c * (x/c + y/c) * (z/c + w/c)
= xz/c + xw/c + yz/c + yw/c where x and z are multiples of c

As such, you can trivially calculate the first three terms without overflow. In my experience, this is often enough to cover typical overflow cases. However, if your divisor is so large that (a % c) * (b % c) overflows, this method still fails. If that is a typical issue for you, you may want to look at other approaches, e.g. halving both c and the larger of a and b until nothing overflows anymore. Doing that without introducing additional error from bias in the process is non-trivial, though; you will probably need to keep a running tally of the error in a separate variable.

Anyway, the code for the above:

long a,b,c;
long bMod = (b % c);
long result = a * (b / c) + (a / c) * bMod + ((a % c) * bMod) / c;

If speed is a big concern (I'm assuming it is at least to some extent, since you're asking this), you may want to consider storing a/c and b/c in variables and calculating the mod through multiplication, e.g. replace (a % c) by (a - aDiv * c) -- this allows you to go from 4 divisions per call to 2.
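
A sketch of that variant, with the quotients stored and the remainders derived by multiplication. The class name SplitMulDiv is hypothetical. The sketch assumes a and b are non-negative, c is positive, and the final result fits in a long; with mixed signs, truncating the last term separately can be off by one. Math.multiplyExact guards the one product that can still overflow when c is large.

```java
// Hypothetical helper, for illustration only.
final class SplitMulDiv {
    // Returns a * b / c for a >= 0, b >= 0, c > 0, assuming the result fits in a long.
    static long mulDiv(long a, long b, long c) {
        if (a < 0 || b < 0 || c <= 0) {
            throw new IllegalArgumentException("sketch handles a, b >= 0 and c > 0 only");
        }
        long aDiv = a / c;
        long bDiv = b / c;
        long aMod = a - aDiv * c;   // same value as a % c, without another division
        long bMod = b - bDiv * c;   // same value as b % c
        // a * b == c * (a * bDiv + aDiv * bMod) + aMod * bMod
        return a * bDiv + aDiv * bMod + Math.multiplyExact(aMod, bMod) / c;
    }
}
```

For example, SplitMulDiv.mulDiv(6_000_000_007L, 9_000_000_011L, 3_000_000_000L) works with remainders 7 and 11, so the only product that needs guarding is 77.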

You assume the following:

long a,b,c;
long result = a*b/c;
  • All 3 operands are of type long
  • The result is of type long
  • a * b may be too large to fit in type long

Mathematically speaking:

(a * b) / c = (a / c) * b = a * (b / c)
  • a / c is surely of type long
  • b / c is surely of type long

As long as your assumption is correct (the result fits in type long), you can divide the larger of (a) and (b) by (c) first and do the multiplication afterwards, so that the result does not exceed the range of type long.

But:

The type long holds no fractional part, so integer division discards the remainder. Therefore we need to carry the remainder of the division along as well, and it still has to be divided by (c):

(a * b) / c = (a / c) * b + ((a % c) * b) / c

This assumes that (a % c) * b does not overflow type long. Alternatively we can use:

(a * b) / c = (b / c) * a + ((b % c) * a) / c

This assumes that (b % c) * a does not overflow type long.

Nevertheless @Jesper is right. As long as you don't plan to do this calculation several millions of times you should be fine with existing big types.
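
For reference, a minimal sketch of the big-type route using java.math.BigInteger. The class name BigMulDiv is hypothetical. longValueExact throws an ArithmeticException when the quotient does not fit in a long, rather than silently wrapping.

```java
import java.math.BigInteger;

// Hypothetical helper, for illustration only.
final class BigMulDiv {
    // Returns a * b / c truncated toward zero, computed without intermediate overflow.
    static long mulDiv(long a, long b, long c) {
        return BigInteger.valueOf(a)
            .multiply(BigInteger.valueOf(b))
            .divide(BigInteger.valueOf(c))   // throws ArithmeticException if c == 0
            .longValueExact();               // throws ArithmeticException if out of range
    }
}
```

This is also a convenient oracle for testing any of the faster variants, in the same way the first answer's main method uses it.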
