Repeated integer division by a runtime constant value

暗喜 2020-12-08 15:23

At some point in my program I compute an integer divisor d. From that point onward d is going to be constant.

Later in the code I will divi

3 Answers
  • 2020-12-08 15:52

    The book "Hacker's Delight" has "Chapter 10: Integer division by constant", spanning 74 pages. You can find all the code examples for free in this directory: http://www.hackersdelight.org/hdcode.htm. In your case, Figs. 10-1, 10-2 and 10-3 are what you want.

    The problem of dividing by a constant d is equivalent to multiplying by c = 1/d. These algorithms calculate such a constant for you. Once you have c, you calculate the quotient like this:

    int divideByMyConstant(int dividend){
      int c = MAGIC; // Given by the algorithm
    
      // since 1/d < 1, c is actually (1<<k)/d to fit nicely in a 32 bit int
      int k = MAGIC_SHIFT; //Also given by the algorithm
    
      long long tmp = (long long)dividend * c; // use 64 bit to hold all the precision...
    
      tmp >>= k; // Drop the k fractional bits: manual fixed-point arithmetic =)
    
      return (int)tmp;
    }
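
    As a concrete illustration of the scheme above, here is a minimal sketch for unsigned 32-bit dividends and the fixed divisors 3 and 10. The constants are the well-known values ceil(2^33/3) and ceil(2^35/10). The function names are made up for this example, and for a divisor chosen at runtime the magic number and shift would have to come from the Hacker's Delight routines instead.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper names; unsigned 32-bit dividends only.

// n / 3 == (n * ceil(2^33 / 3)) >> 33
inline std::uint32_t divideBy3(std::uint32_t n) {
    const std::uint64_t c = 0xAAAAAAABu;   // ceil(2^33 / 3)
    return static_cast<std::uint32_t>((n * c) >> 33);
}

// n / 10 == (n * ceil(2^35 / 10)) >> 35
inline std::uint32_t divideBy10(std::uint32_t n) {
    const std::uint64_t c = 0xCCCCCCCDu;   // ceil(2^35 / 10)
    return static_cast<std::uint32_t>((n * c) >> 35);
}

// Compares both helpers with the built-in operator on a spread of inputs.
inline bool selfCheck() {
    for (std::uint64_t n = 0; n <= 0xFFFFFFFFull; n += 65521) {
        const std::uint32_t v = static_cast<std::uint32_t>(n);
        if (divideBy3(v) != v / 3 || divideBy10(v) != v / 10) return false;
    }
    return divideBy3(0xFFFFFFFFu) == 0xFFFFFFFFu / 3 &&
           divideBy10(0xFFFFFFFFu) == 0xFFFFFFFFu / 10;
}
```

    The product is formed in 64 bits, so nothing is lost before the shift. A signed version needs an extra correction step for negative dividends, which this sketch leaves out.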
    
  • 2020-12-08 15:56

    Update: in my original answer I pointed to an algorithm, mentioned in a prior thread, that compilers use to generate code for division by a constant. The assembly code was written to match a document linked from that prior thread. The compiler-generated code uses slightly different instruction sequences depending on the divisor.

    In this situation the divisor is not known until runtime, so a single common algorithm is desired. The example in geza's answer shows such an algorithm, which could be inlined as assembly code with GCC, but Visual Studio doesn't support inline assembly in 64-bit mode. With Visual Studio there is a trade-off between the extra code involved when using intrinsics and the cost of calling a function written in assembly. On my system (Intel 3770K, 3.5 GHz) I tried calling a single function that does |mul add adc shr|, and I also tried using a function pointer to select the shorter sequences |mul shr| or |shr(1) mul shr| depending on the divisor, but this provided little or no gain, depending on the compiler. The main overhead in this case is the call (versus |mul add adc shr|). Even with the call overhead, the sequence |call mul add adc shr ret| averaged about 4 times as fast as divide on my system.

    Note that the libdivide source code linked from geza's answer does not have a common routine that can handle a divisor == 1. The libdivide common sequence is multiply, subtract, shift(1), add, shift, versus the multiply, add, adc, shift sequence of geza's example C++ code.
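
    To make the two sequences concrete, here is a small 32-bit sketch with the divisor fixed at 7 (the timings above are for the 64-bit versions). The function names are illustrative only, and the constants are worked out by hand for this one divisor: ceil(2^35/7) minus its implicit top bit for the first form, floor(2^34/7) for the second.

```cpp
#include <cassert>
#include <cstdint>

// Both functions compute n / 7 for unsigned 32-bit n; names are illustrative.

// multiply, subtract, shift(1), add, shift.
// The full multiplier ceil(2^35 / 7) needs 33 bits, so only its low 32 bits
// (0x24924925) are stored; the subtract / shift(1) / add steps account for
// the implicit top bit without overflowing 32 bits.
inline std::uint32_t div7SubShiftAdd(std::uint32_t n) {
    const std::uint64_t m = 0x24924925u;
    const std::uint32_t t = static_cast<std::uint32_t>((n * m) >> 32);  // high half
    return (((n - t) >> 1) + t) >> 2;
}

// multiply, add, shift: (n * mul + add) >> 32 >> shift with
// mul = add = floor(2^34 / 7) and shift = 2. In 64-bit assembly the addition
// becomes an add/adc pair on the two halves of the 128-bit product.
inline std::uint32_t div7MulAddShift(std::uint32_t n) {
    const std::uint64_t mul = 0x92492492u;
    const std::uint64_t add = 0x92492492u;
    return static_cast<std::uint32_t>((n * mul + add) >> 32) >> 2;
}

// Checks both sequences against the built-in operator on a spread of inputs.
inline bool bothMatchBuiltin() {
    for (std::uint64_t n = 0; n <= 0xFFFFFFFFull; n += 65521) {
        const std::uint32_t v = static_cast<std::uint32_t>(n);
        if (div7SubShiftAdd(v) != v / 7 || div7MulAddShift(v) != v / 7) return false;
    }
    return true;
}
```

    The second form has no special case for a 33-bit multiplier, which is what allows one routine to serve every divisor.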


    My original answer: the example code below uses the algorithm described in a prior thread.

    Why does GCC use multiplication by a strange number in implementing integer division?

    This is a link to the document mentioned in the other thread:

    http://gmplib.org/~tege/divcnst-pldi94.pdf

    The example code below is based on the pdf document and is meant for Visual Studio, using ml64 (64 bit assembler), running on Windows (64 bit OS). The code with labels main... and dcm... is the code to generate a preshift (rspre, number of trailing zero bits in divisor), multiplier, and postshift (rspost). The code with labels dct... is the code to test the method.

            includelib      msvcrtd
            includelib      oldnames
    
    sw      equ     8                       ;size of word
    
            .data
    arrd    dq      1                       ;array of test divisors
            dq      2
            dq      3
            dq      4
            dq      5
            dq      7
            dq      256
            dq      3*256
            dq      7*256
            dq      67116375
            dq      07fffffffffffffffh      ;max divisor
            dq      0
            .data?
    
            .code
            PUBLIC  main
    
    main    PROC
            push    rbp
            push    rdi
            push    rsi
            sub     rsp,64                  ;allocate stack space
            mov     rbp,rsp
            lea     rsi,arrd                ;set ptr to array of divisors
            mov     [rbp+6*sw],rsi
            jmp     main1
    
    main0:  mov     [rbp+0*sw],rbx          ;[rbp+0*sw] = rbx = divisor = d
            cmp     rbx,1                   ;if d <= 1, q=n or overflow
            jbe     main1
            bsf     rcx,rbx                 ;rcx = rspre
            mov     [rbp+1*sw],rcx          ;[rbp+1*sw] = rspre
            shr     rbx,cl                  ;rbx = d>>rspre
            bsr     rcx,rbx                 ;rcx = floor(log2(rbx))
            mov     rsi,1                   ;rsi = 2^floor(log2(rbx))
            shl     rsi,cl
            cmp     rsi,rbx                 ;br if power of 2
            je      dcm03
            inc     rcx                     ;rcx = ceil(log2(rbx)) = L = rspost
            shl     rsi,1                   ;rsi = 2^L
    ;       jz      main1                   ;d > 08000000000000000h, use compare
            add     rcx,[rbp+1*sw]          ;rcx = L+rspre
            cmp     rcx,8*sw                ;if d > 08000000000000000h, use compare
            jae     main1
            mov     rax,1                   ;[rbp+3*sw] = 2^(L+rspre)
            shl     rax,cl
            mov     [rbp+3*sw],rax
            sub     rcx,[rbp+1*sw]          ;rcx = L
            xor     rdx,rdx
            mov     rax,rsi                 ;hi N bits of 2^(N+L)
            div     rbx                     ;rax == 1
            xor     rax,rax                 ;lo N bits of 2^(N+L)
            div     rbx
            mov     rdi,rax                 ;rdi = mlo % 2^N
            xor     rdx,rdx
            mov     rax,rsi                 ;hi N bits of 2^(N+L) + 2^(L+rspre)
            div     rbx                     ;rax == 1
            mov     rax,[rbp+3*sw]          ;lo N bits of 2^(N+L) + 2^(L+rspre)
            div     rbx                     ;rax = mhi % 2^N
            mov     rdx,rdi                 ;rdx = mlo % 2^N
            mov     rbx,8*sw                ;rbx = e = # bits in word
    dcm00:  mov     rsi,rdx                 ;rsi = mlo/2
            shr     rsi,1
            mov     rdi,rax                 ;rdi = mhi/2
            shr     rdi,1
            cmp     rsi,rdi                 ;break if mlo >= mhi
            jae     short dcm01
            mov     rdx,rsi                 ;rdx = mlo/2
            mov     rax,rdi                 ;rax = mhi/2
            dec     rbx                     ;e -= 1
            loop    dcm00                   ;loop if --rspost != 0
    dcm01:  mov     [rbp+2*sw],rcx          ;[rbp+2*sw] = rspost
            cmp     rbx,8*sw                ;br if N+1 bit multiplier
            je      short dcm02
            xor     rdx,rdx
            mov     rdi,1                   ;rax = m = mhi + 2^e
            mov     rcx,rbx
            shl     rdi,cl
            or      rax,rdi
            jmp     short dct00
    
    dcm02:  mov     rdx,1                   ;rdx = 2^N
            dec     qword ptr [rbp+2*sw]    ;dec rspost
            jmp     short dct00
    
    dcm03:  mov     rcx,[rbp+1*sw]          ;rcx = rsc
            jmp     short dct10
    
    ;       test    rbx = n, rdx = m bit N, rax = m%(2^N)
    ;               [rbp+1*sw] = rspre, [rbp+2*sw] = rspost
    
    dct00:  mov     rdi,rdx                 ;rdi:rsi = m
            mov     rsi,rax
            mov     rbx,0fffffffff0000000h  ;[rbp+5*sw] = rbx = n
    dct01:  mov     [rbp+5*sw],rbx
            mov     rdx,rdi                 ;rdx:rax = m
            mov     rax,rsi
            mov     rcx,[rbp+1*sw]          ;rbx = n>>rspre
            shr     rbx,cl
            or      rdx,rdx                 ;br if 65 bit m
            jnz     short dct02
            mul     rbx                     ;rdx = (n*m)>>N
            jmp     short dct03
    
    dct02:  mul     rbx
            sub     rbx,rdx
            shr     rbx,1
            add     rdx,rbx
    dct03:  mov     rcx,[rbp+2*sw]          ;rcx = rspost
            shr     rdx,cl                  ;rdx = q = quotient
            mov     [rbp+4*sw],rdx          ;[rbp+4*sw] = q
            xor     rdx,rdx                 ;rdx:rax = n
            mov     rax,[rbp+5*sw]
            mov     rbx,[rbp+0*sw]          ;rbx = d
            div     rbx                     ;rax = n/d
            mov     rdx,[rbp+4*sw]          ;rdx = q
            cmp     rax,rdx                 ;br if ok
            je      short dct04
            nop                             ;debug check
    dct04:  mov     rbx,[rbp+5*sw]
            inc     rbx
            jnz     short dct01
            jmp     short main1
    
    ;       test    rbx = n, rcx = rsc
    dct10:  mov     rbx,0fffffffff0000000h  ;rbx = n
    dct11:  mov     rsi,rbx                 ;rsi = n
            shr     rsi,cl                  ;rsi = n>>rsc
            xor     edx,edx
            mov     rax,rbx
            mov     rdi,[rbp+0*sw]          ;rdi = d
            div     rdi
            cmp     rax,rsi                 ;br if ok
            je      short dct12
            nop
    dct12:  inc     rbx
            jnz     short dct11
    
    main1:  mov     rsi,[rbp+6*sw]          ;rsi ptr to divisor
            mov     rbx,[rsi]               ;rbx = divisor = d
            add     rsi,1*sw                ;advance ptr
            mov     [rbp+6*sw],rsi
            or      rbx,rbx
            jnz     main0                   ;br if not end table
    
            add     rsp,64                  ;restore regs
            pop     rsi
            pop     rdi
            pop     rbp
            xor     rax,rax
            ret     0
    
    main    ENDP
            END
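
    For readers who find the assembly hard to follow, the sketch below restates the idea (pre-shift, multiplier, post-shift) in C++ for 32-bit words. It follows the CHOOSE_MULTIPLIER and unsigned-division procedures of the linked paper rather than translating the assembly line by line: unlike the assembly, it applies a pre-shift only when the multiplier would otherwise need 33 bits. All names are invented for the example.

```cpp
#include <cassert>
#include <cstdint>

// Sketch for N = 32-bit words; every name here is an illustrative assumption.
struct UDiv32 {
    std::uint32_t d;   // original divisor
    std::uint64_t m;   // multiplier, may need N + 1 = 33 bits
    int shPre;         // pre-shift (trailing zero bits taken out of d)
    int shPost;        // post-shift
    int l;             // ceil(log2(d))
};

// Picks the smallest multiplier / post-shift pair that is exact for
// dividends of `prec` bits. Requires 1 <= d <= 2^31.
static void chooseMultiplier(std::uint32_t d, int prec,
                             std::uint64_t &m, int &shPost, int &l) {
    l = 0;
    while ((std::uint64_t(1) << l) < d) ++l;          // l = ceil(log2(d))
    shPost = l;
    std::uint64_t mLow = (std::uint64_t(1) << (32 + l)) / d;
    std::uint64_t mHigh = ((std::uint64_t(1) << (32 + l)) +
                           (std::uint64_t(1) << (32 + l - prec))) / d;
    while (mLow / 2 < mHigh / 2 && shPost > 0) {      // shorten the multiplier
        mLow /= 2;
        mHigh /= 2;
        --shPost;
    }
    m = mHigh;
}

UDiv32 makeUDiv32(std::uint32_t d) {
    assert(d != 0);
    UDiv32 r = {d, 0, 0, 0, 32};
    if (d > 0x80000000u) return r;                    // handled by a compare
    chooseMultiplier(d, 32, r.m, r.shPost, r.l);
    if (r.m >= (std::uint64_t(1) << 32) && (d & 1) == 0) {
        std::uint32_t dOdd = d;                       // strip trailing zeros
        while ((dOdd & 1) == 0) { dOdd >>= 1; ++r.shPre; }
        int unused;
        chooseMultiplier(dOdd, 32 - r.shPre, r.m, r.shPost, unused);
    }
    return r;
}

std::uint32_t divide(std::uint32_t n, const UDiv32 &r) {
    if (r.d > 0x80000000u) return n >= r.d ? 1u : 0u; // quotient is 0 or 1
    if ((r.d & (r.d - 1)) == 0) return n >> r.l;      // power of two
    if (r.m >= (std::uint64_t(1) << 32)) {            // 33-bit multiplier
        const std::uint64_t mLow32 = r.m - (std::uint64_t(1) << 32);
        const std::uint32_t t = static_cast<std::uint32_t>((mLow32 * n) >> 32);
        return (t + ((n - t) >> 1)) >> (r.shPost - 1);
    }
    return static_cast<std::uint32_t>((r.m * (n >> r.shPre)) >> 32) >> r.shPost;
}

// Compares against the built-in operator: a strided sweep of the whole range
// plus the top 65536 dividends.
bool checkDivisor(std::uint32_t d) {
    const UDiv32 r = makeUDiv32(d);
    for (std::uint64_t n = 0; n <= 0xFFFFFFFFull; n += 0x10001) {
        const std::uint32_t v = static_cast<std::uint32_t>(n);
        if (divide(v, r) != v / d) return false;
    }
    for (std::uint64_t n = 0xFFFF0000ull; n <= 0xFFFFFFFFull; ++n) {
        const std::uint32_t v = static_cast<std::uint32_t>(n);
        if (divide(v, r) != v / d) return false;
    }
    return true;
}
```

    Divisors above 2^31 are handled with a compare, as the assembly's comments also suggest, because the quotient can then only be 0 or 1.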
    
  • 2020-12-08 16:01

    There is a library for this—libdivide:

    libdivide is an open source library for optimizing integer division

    libdivide allows you to replace expensive integer divides with comparatively cheap multiplication and bitshifts. Compilers usually do this, but only when the divisor is known at compile time. libdivide allows you to take advantage of it at runtime. The result is that integer division can become faster - a lot faster. Furthermore, libdivide allows you to divide an SSE2 vector by a runtime constant, which is especially nice because SSE2 has no integer division instructions!

    libdivide is free and open source with a permissive license. The name "libdivide" is a bit of a joke, as there is no library per se: the code is packaged entirely as a single header file, with both a C and a C++ API.

    You can read about the algorithm behind it at this blog; for example, this entry.

    Basically, the algorithm behind it is the same one that compilers use to optimize division by a constant, except that it allows these strength-reduction optimizations to be done at run-time.

    Note: you can create an even faster version of libdivide. The idea is that for every divisor you can always create a triplet (mul/add/shift) such that this expression gives the result: (num*mul+add)>>shift (the multiply here is a wide multiply). Interestingly, this method could beat the compiler version for constant division in several microbenchmarks!


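    Here's my implementation (the typedefs and the indexOfMostSignificantBit helper at the top are assumed minimal stand-ins for definitions that are not shown; the general algorithm is in set() and operator/):

    #include <cstdint>

    // Assumed definitions, added so that the snippet compiles on its own.
    typedef std::uint32_t u32;
    typedef std::uint64_t u64;
    typedef std::int32_t  s32;

    // Assumed helper: index of the highest set bit, i.e. floor(log2(v)), v != 0.
    static s32 indexOfMostSignificantBit(u32 v) {
        s32 i = 0;
        while (v >>= 1) i++;
        return i;
    }

    struct Divider_u32 {
        u32 mul;
        u32 add;
        s32 shift;

        void set(u32 divider);  // divider must be nonzero
    };

    void Divider_u32::set(u32 divider) {
        s32 l = indexOfMostSignificantBit(divider);
        if (divider&(divider-1)) {
            // Not a power of two: mul = floor(2^(l+32) / divider)
            u64 m = static_cast<u64>(1)<<(l+32);
            mul = static_cast<u32>(m/divider);

            // rem = m % divider (the low 32 bits are enough, since rem < divider)
            u32 rem = static_cast<u32>(m)-mul*divider;
            u32 e = divider-rem;

            if (e<static_cast<u32>(1)<<l) {
                mul++;      // rounding the multiplier up is exact
                add = 0;
            } else {
                add = mul;  // otherwise round down and compute (v+1)*mul
            }
            shift = l;
        } else {
            if (divider==1) {
                // ((v+1)*(2^32-1))>>32 == v for every 32-bit v
                mul = 0xffffffff;
                add = 0xffffffff;
                shift = 0;
            } else {
                // Power of two: (v*2^31)>>32 == v>>1, then shift by l-1 more
                mul = 0x80000000;
                add = 0;
                shift = l-1;
            }
        }
    }

    // Wide (64-bit) multiply-add, take the high 32 bits, then shift.
    u32 operator/(u32 v, const Divider_u32 &div) {
        u32 t = static_cast<u32>((static_cast<u64>(v)*div.mul+div.add)>>32)>>div.shift;

        return t;
    }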
    