Compute (a*b)%m FAST for 64-bit unsigned arguments in C(++) on x86-64 platforms?

大憨熊 posted on 2019-12-01 07:33:19

Question


I'm looking for a fast way to compute (a*b) modulo n (in the mathematical sense) for a, b, n of type uint64_t. I could live with preconditions such as n!=0, or even a<n && b<n.

Notice that the C expression (a*b)%n won't cut it, because the product is truncated to 64 bits. I'm looking for (uint64_t)(((uint128_t)a*b)%n), except that I do not have a uint128_t (that I know of, in Visual C++).

I'm after a Visual C++ (preferably) or GCC/clang intrinsic that makes the best use of the hardware available on x86-64 platforms; or, if that can't be done, a portable inline function.
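For reference, the expression the question asks for can be written directly on GCC and clang, which provide unsigned __int128 as an extension. This is only a sketch of the intended semantics, not an answer for Visual C++; the name mulmod_ref is made up for illustration.

```c
#include <stdint.h>

/* Reference semantics only: relies on the unsigned __int128 extension
   of GCC and clang, which MSVC does not provide. Requires n != 0. */
static inline uint64_t mulmod_ref(uint64_t a, uint64_t b, uint64_t n) {
    unsigned __int128 product = (unsigned __int128)a * b; /* full 128-bit product */
    return (uint64_t)(product % n);                        /* result is < n, so it fits */
}
```

Depending on the compiler, the 128-bit % may be compiled to a library call rather than a single div, so this shows what is wanted rather than how fast it can be.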


Answer 1:


Ok, how about this (not tested)

modmul:
; rcx = a
; rdx = b
; r8 = n
mov rax, rdx
mul rcx
div r8
mov rax, rdx
ret

The precondition is that a * b / n <= ~0ULL; otherwise there will be a divide error. That's a slightly less strict condition than a < n && b < n: one of them can be bigger than n as long as the other is small enough.
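The precondition can be restated as "the high 64 bits of a*b are smaller than n", because the quotient fits in 64 bits exactly when a*b < n * 2^64. A minimal sketch of that check follows, assuming the GCC/clang unsigned __int128 extension; the helper name mulmod_div_is_safe is hypothetical and not part of the answer.

```c
#include <stdint.h>

/* Hypothetical helper: returns 1 when the mul/div sequence cannot raise
   a divide error, i.e. when floor(a*b / n) fits in 64 bits. That holds
   exactly when the high 64 bits of a*b are < n. */
static inline int mulmod_div_is_safe(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t hi = (uint64_t)(((unsigned __int128)a * b) >> 64);
    return hi < n; /* also reports 0 for n == 0 */
}
```

For example, a = 2^64-1, b = 3, n = 3 passes the check even though a is far larger than n, while a = b = 2^64-1 with n = 2 does not.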

Unfortunately it has to be assembled and linked in separately, because MSVC doesn't support inline asm for 64bit targets.

It's also still slow: the real problem is the 64-bit div, which can take nearly a hundred cycles (seriously, up to 90 cycles on Nehalem, for example).




Answer 2:


You could do it the old-fashioned way with shift/add/subtract. The code below assumes a < n and n < 2^63 (so things don't overflow):

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b & 1) {                    // this bit of b is set: add the current a*2^k mod n
            if ((rv += a) >= n) rv -= n;
        }
        if ((a += a) >= n) a -= n;      // double a modulo n for the next bit
        b >>= 1;
    }
    return rv;
}

You could use while (a && b) for the loop instead, to short-circuit once a has been reduced to 0 (which happens as soon as some a*2^k is a multiple of n). It will be slightly slower (more comparisons, although the extra branch is likely to be predicted correctly) when that never happens.
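Spelled out, the short-circuit variant might look like the sketch below. It keeps the same preconditions (a < n and n < 2^63); the name mulmod_short is only there to distinguish it from the version above.

```c
#include <stdint.h>

/* Shift/add multiply modulo n that stops early once a has become 0.
   Preconditions: a < n and n < 2^63. */
uint64_t mulmod_short(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (a && b) {                /* nothing more can be added once a == 0 */
        if (b & 1) {
            rv += a;
            if (rv >= n) rv -= n;
        }
        a += a;                     /* double a modulo n */
        if (a >= n) a -= n;
        b >>= 1;
    }
    return rv;
}
```

With a = 4, b = 8, n = 8 the loop exits after the first iteration, because doubling a already gives 0 modulo n.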

If you really, absolutely, need that last bit (allowing n up to 2^64-1), you can use:

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b & 1) {
            rv += a;
            // rv < a means the addition wrapped around 2^64
            if (rv < a || rv >= n) rv -= n;
        }
        uint64_t t = a;
        a += a;
        // a < t means the doubling wrapped around 2^64
        if (a < t || a >= n) a -= n;
        b >>= 1;
    }
    return rv;
}
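The rv < a and a < t tests are the usual unsigned wraparound check: if a 64-bit sum wrapped, it is smaller than either operand, and subtracting n (again modulo 2^64) still lands on the right residue. Isolated as a helper, the step might be sketched like this; addmod is an illustrative name, not something the answer defines.

```c
#include <stdint.h>

/* (x + y) mod n for x, y < n and any n up to 2^64-1.
   Unsigned overflow is well defined in C: the sum wraps modulo 2^64. */
static inline uint64_t addmod(uint64_t x, uint64_t y, uint64_t n) {
    uint64_t s = x + y;             /* may wrap around 2^64 */
    if (s < x || s >= n)            /* wrapped, or simply not reduced yet */
        s -= n;                     /* the true sum is < 2n, so one subtraction suffices */
    return s;
}
```

For instance, with n = 2^64-1 and x = y = n-1 the raw sum wraps to 2^64-4, the s < x test catches it, and the result is n-2.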

Alternately, just use GCC inline assembly to access the underlying x64 instructions:

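inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv;
    // rdx:rax = rax * b (high half ends up in rv, low half back in a)
    asm ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b));
    // rdx = rdx:rax % n; raises a divide error if the quotient needs more than 64 bits
    asm ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n));
    return rv;
}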

The 64-bit div instruction is really slow, however, so the loop might actually be faster. You'd need to profile to be sure.
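A rough way to do that profiling is sketched below, under several assumptions: GCC or clang (for unsigned __int128, used here as a stand-in for the mul/div version), clock() as the timer, and an arbitrary large modulus. The names mulmod_loop, mulmod_wide and time_mulmod are made up for this sketch.

```c
#include <stdint.h>
#include <time.h>

typedef uint64_t (*mulmod_fn)(uint64_t, uint64_t, uint64_t);

/* Full-range shift/add loop (requires a < n). */
static uint64_t mulmod_loop(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b & 1) {
            rv += a;
            if (rv < a || rv >= n) rv -= n;
        }
        uint64_t t = a;
        a += a;
        if (a < t || a >= n) a -= n;
        b >>= 1;
    }
    return rv;
}

/* Stand-in for the mul/div version; the compiler may emit a library
   call for the 128-bit modulo instead of a single div. */
static uint64_t mulmod_wide(uint64_t a, uint64_t b, uint64_t n) {
    return (uint64_t)(((unsigned __int128)a * b) % n);
}

/* Runs fn over a fixed pseudo-random input sequence. Returns the elapsed
   CPU time in seconds and stores an XOR checksum of all results, which
   keeps the calls from being optimized away and allows a sanity check. */
static double time_mulmod(mulmod_fn fn, long iters, uint64_t *checksum) {
    const uint64_t n = 0xFFFFFFFFFFFFFFC5ull;   /* arbitrary large modulus */
    uint64_t x = 0x9E3779B97F4A7C15ull, acc = 0;
    clock_t t0 = clock();
    for (long i = 0; i < iters; i++) {
        x = x * 6364136223846793005ull + 1442695040888963407ull; /* LCG step */
        acc ^= fn(x % n, x >> 7, n);            /* first argument is kept < n */
    }
    clock_t t1 = clock();
    *checksum = acc;
    return (double)(t1 - t0) / CLOCKS_PER_SEC;
}
```

Calling time_mulmod once per candidate with the same iteration count (say, a few million) and comparing both the times and the checksums gives a first impression; results will depend heavily on the CPU, the compiler flags and the distribution of b.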




Answer 3:


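The MSVC intrinsic for the full 64x64-bit multiply is named _umul128 (_mul128 is the signed variant); it returns the low 64 bits of the product and stores the high 64 bits through a pointer. The helper below shifts a value left by a given number of bits, reducing modulo n at every step: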

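typedef unsigned long long BIG;

// computes (v << by) mod n
// handles only the "hard" case when high bit of n is set
BIG shl_mod( BIG v, BIG n, int by )
{
    if (v >= n) v -= n;          // now v < n (n >= 2^63, so one subtraction suffices)
    while (by--) {
        if (v >= (n-v))          // 2*v >= n, so the doubled value needs reducing
            v -= n-v;            // v = 2*v - n, computed without overflow
        else
            v <<= 1;
    }
    return v;
}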

Now you can use shl_mod(B, n, 64), which computes (B * 2^64) mod n; B here is presumably the high 64 bits of the 128-bit product.
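The answer stops before combining the pieces, so the following is only one possible completion, under stated assumptions: n has its high bit set (the "hard" case), the high half is reduced with shl_mod, the low half with a single conditional subtraction, and the two residues are added without overflow. The names mul_64x64 and mulmod_hard are hypothetical, and the non-MSVC branch uses the GCC/clang unsigned __int128 extension in place of _umul128.

```c
#include <stdint.h>
#if defined(_MSC_VER) && defined(_M_X64)
#include <intrin.h>                 /* _umul128 */
#endif

/* (v << by) mod n, for n with its high bit set (n >= 2^63). */
static uint64_t shl_mod(uint64_t v, uint64_t n, int by) {
    if (v >= n) v -= n;
    while (by--) {
        if (v >= n - v) v -= n - v; /* 2v >= n: take 2v - n */
        else            v <<= 1;
    }
    return v;
}

/* Full 64x64 -> 128-bit product, split into high and low halves. */
static void mul_64x64(uint64_t a, uint64_t b, uint64_t *hi, uint64_t *lo) {
#if defined(_MSC_VER) && defined(_M_X64)
    unsigned __int64 h;
    *lo = _umul128(a, b, &h);
    *hi = h;
#else
    unsigned __int128 p = (unsigned __int128)a * b;
    *hi = (uint64_t)(p >> 64);
    *lo = (uint64_t)p;
#endif
}

/* (a * b) mod n, assuming n >= 2^63. */
uint64_t mulmod_hard(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t hi, lo;
    mul_64x64(a, b, &hi, &lo);
    uint64_t h = shl_mod(hi, n, 64);        /* (hi * 2^64) mod n */
    uint64_t l = lo >= n ? lo - n : lo;     /* lo mod n, since lo < 2^64 <= 2n */
    /* (h + l) mod n without overflowing 64 bits */
    return h >= n - l ? h - (n - l) : h + l;
}
```

The 64 iterations of shl_mod make this closer in cost to the shift/add loop than to a single div; it mainly shows how the multiply intrinsic and the helper fit together.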




Answer 4:


Having no inline assembly kind of sucks. Anyway, the function call overhead is actually extremely small. Parameters are passed in volatile registers and no cleanup is needed.

I don't have an assembler, and x64 targets don't support __asm, so I had no choice but to "assemble" my function from opcodes myself.

Obviously it depends on . I'm using mpir (gmp) as a reference to show the function produces correct results.


// stdafx.h is assumed to pull in <windows.h>, <stdint.h>, <stdio.h> and the MPIR/GMP header
#include "stdafx.h"

// mulmod64(a, b, m) == (a * b) % m
typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);

uint8_t mulmod64_opcodes[] = {
    0x48, 0x89, 0xC8, // mov rax, rcx
    0x48, 0xF7, 0xE2, // mul rdx
    0x4C, 0x89, 0xC1, // mov rcx, r8
    0x48, 0xF7, 0xF1, // div rcx
    0x48, 0x89, 0xD0, // mov rax,rdx
    0xC3              // ret
};

mulmod64_fnptr_t mulmod64_fnptr;

void init() {
    DWORD dwOldProtect;
    VirtualProtect(
        &mulmod64_opcodes,
        sizeof(mulmod64_opcodes),
        PAGE_EXECUTE_READWRITE,
        &dwOldProtect);
    // NOTE: reinterpret byte array as a function pointer
    mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
}

int main() {
    init();

    uint64_t a64 = 2139018971924123ull;
    uint64_t b64 = 1239485798578921ull;
    uint64_t m64 = 8975489368910167ull;

    // reference code
    mpz_t a, b, c, m, r;
    mpz_inits(a, b, c, m, r, NULL);
    mpz_set_ui(a, a64);
    mpz_set_ui(b, b64);
    mpz_set_ui(m, m64);
    mpz_mul(c, a, b);
    mpz_mod(r, c, m);

    gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r);

    // using mulmod64
    uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
    printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64);
    return 0;
}
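As a different technique from the hand-assembled opcodes, newer MSVC versions expose the same mul/div pair through intrinsics: _umul128 from <intrin.h> and _udiv128 from <immintrin.h>. To my knowledge _udiv128 only exists from Visual Studio 2019 onward, so treat its availability as an assumption. The sketch below uses the hypothetical name mulmod64_intrin and falls back to unsigned __int128 on GCC/clang so that it compiles elsewhere.

```c
#include <stdint.h>
#if defined(_MSC_VER) && defined(_M_X64)
#include <intrin.h>                 /* _umul128 */
#include <immintrin.h>              /* _udiv128 (assumed: Visual Studio 2019 or later) */
#endif

/* (a * b) % m. On the MSVC path the precondition of the mul/div sequence
   applies: the high half of a*b must be < m (guaranteed when a < m or
   b < m), otherwise the division raises a divide error. */
static inline uint64_t mulmod64_intrin(uint64_t a, uint64_t b, uint64_t m) {
#if defined(_MSC_VER) && defined(_M_X64)
    unsigned __int64 hi, rem;
    unsigned __int64 lo = _umul128(a, b, &hi);  /* hi:lo = a * b */
    (void)_udiv128(hi, lo, m, &rem);            /* rem = hi:lo % m */
    return rem;
#else
    /* Fallback for GCC/clang: 128-bit arithmetic via a compiler extension. */
    return (uint64_t)(((unsigned __int128)a * b) % m);
#endif
}
```

It can be called exactly like mulmod64_fnptr in the program above, without the VirtualProtect step.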



Source: https://stackoverflow.com/questions/20572031/compute-abm-fast-for-64-bit-unsigned-arguments-in-c-on-x86-64-platforms
