Repeated integer division by a runtime constant value

前端 未结 3 1697
暗喜
暗喜 2020-12-08 15:23

At some point in my program I compute an integer divisor d. From that point onward d is going to be constant.

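Later in the code I will divide by d many times, so I would like these repeated integer divisions to be faster than a plain divide instruction.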

3条回答
  •  刺人心
    刺人心 (楼主)
    2020-12-08 15:56

    Update: in my original answer, I noted an algorithm, mentioned in a prior thread, that compilers use to generate code for division by a constant. The assembly code was written to match a document linked to from that prior thread. The compiler-generated code uses slightly different instruction sequences depending on the divisor.

    In this situation, the divisor is not known until runtime, so a general algorithm is desired. The example in geza's answer shows such a general algorithm. It could be inlined as assembly code with GCC, but Visual Studio doesn't support inline assembly in 64-bit mode. With Visual Studio there is a trade-off between the extra code involved in using intrinsics and the cost of calling a function written in assembly. On my system (Intel 3770K, 3.5 GHz) I tried calling a single function that does |mul add adc shr|. I also tried using a function pointer to select shorter sequences, |mul shr| or |shr(1) mul shr|, depending on the divisor, but this provided little or no gain, depending on the compiler. The main overhead in this case is the call itself, compared with the |mul add adc shr| sequence. Even with the call overhead, the sequence |call mul add adc shr ret| averaged about 4 times as fast as a divide on my system.

    Note that the linked-to source code for libdivide in geza's answer does not have a general routine that can handle a divisor of 1. The libdivide general sequence is multiply, subtract, shift(1), add, shift, whereas geza's example C++ sequence is multiply, add, adc, shift.


    My original answer: the example code below uses the algorithm described in a prior thread.

    Why does GCC use multiplication by a strange number in implementing integer division?

    This is a link to the document mentioned in the other thread:

    http://gmplib.org/~tege/divcnst-pldi94.pdf

    The example code below is based on the PDF document. It is meant for Visual Studio, assembled with ml64 (the 64-bit assembler), and runs on 64-bit Windows. The code with labels main... and dcm... generates a pre-shift (rspre, the number of trailing zero bits in the divisor), a multiplier, and a post-shift (rspost). The code with labels dct... tests the method.

            includelib      msvcrtd
            includelib      oldnames
    
    sw      equ     8                       ;size of word
    
            .data
    arrd    dq      1                       ;array of test divisors
            dq      2
            dq      3
            dq      4
            dq      5
            dq      7
            dq      256
            dq      3*256
            dq      7*256
            dq      67116375
            dq      07fffffffffffffffh      ;max divisor
            dq      0
            .data?
    
            .code
            PUBLIC  main
    
    ;-----------------------------------------------------------------------
    ; int main(void)
    ; ABI:     Microsoft x64 (ml64). Presumably entered from the C runtime
    ;          startup code (includelib msvcrtd) - NOTE(review): confirm.
    ; Purpose: for every nonzero divisor d in arrd (the table ends at 0),
    ;          derive a pre-shift (rspre), a multiplier m and a post-shift
    ;          (rspost) as described in Granlund & Montgomery,
    ;          "Division by Invariant Integers using Multiplication",
    ;          then compare the multiply-based quotient with DIV for every
    ;          n from 0fffffffff0000000h through 0ffffffffffffffffh.
    ; Out:     rax = 0 if every quotient matched, 1 if any mismatch was seen.
    ; Saves:   rbp, rbx, rdi, rsi (callee-saved; all used as work registers).
    ; Clobb:   rcx, rdx, r8, flags.
    ; Skips:   non-power-of-2 divisors d > 2^63 (L + rspre >= 64); those
    ;          would need a compare, q = (n >= d), which is not implemented.
    ;
    ; Stack frame: rbp = rsp after the prologue; slots are rbp-relative.
    ;-----------------------------------------------------------------------
    fr_d    equ     0*sw                    ;divisor d from the table
    fr_pre  equ     1*sw                    ;rspre = trailing zero bits of d
    fr_post equ     2*sw                    ;rspost = post-shift count
    fr_pow  equ     3*sw                    ;scratch: 2^(L+rspre)
    fr_q    equ     4*sw                    ;q from the multiply sequence
    fr_n    equ     5*sw                    ;current test dividend n
    fr_ptr  equ     6*sw                    ;pointer to the next table entry
    fr_size equ     7*sw                    ;56 bytes: with 4 pushes keeps rsp 16-aligned

    main    PROC
            push    rbp                     ;save callee-saved registers
            push    rbx
            push    rdi
            push    rsi
            sub     rsp,fr_size             ;allocate the local slots
            mov     rbp,rsp
            xor     r8d,r8d                 ;r8d = 0: no mismatch seen yet
            lea     rsi,arrd                ;start of the divisor table
            mov     [rbp+fr_ptr],rsi
            jmp     main1

    ; rbx = d, nonzero here (a 0 entry ends the table), so bsf/bsr are defined.
    ; d == 1 takes the power-of-2 path with rspre = 0, i.e. it checks q = n.
    main0:  mov     [rbp+fr_d],rbx          ;save d
            bsf     rcx,rbx                 ;rcx = rspre = trailing zero bits of d
            mov     [rbp+fr_pre],rcx
            shr     rbx,cl                  ;rbx = d >> rspre (odd part of d)
            bsr     rcx,rbx                 ;rcx = floor(log2(rbx))
            mov     esi,1                   ;rsi = 2^floor(log2(rbx))
            shl     rsi,cl
            cmp     rsi,rbx                 ;odd part is 1: d is a power of 2
            je      dcm03
            inc     rcx                     ;rcx = L = ceil(log2(rbx))
            shl     rsi,1                   ;rsi = 2^L
            add     rcx,[rbp+fr_pre]        ;rcx = L + rspre
            cmp     rcx,8*sw                ;L + rspre >= 64 only when d > 2^63:
            jae     main1                   ;  not handled, skip this divisor
            mov     eax,1                   ;[fr_pow] = 2^(L+rspre)
            shl     rax,cl
            mov     [rbp+fr_pow],rax
            sub     rcx,[rbp+fr_pre]        ;rcx = L (initial rspost)

    ; mlo = floor(2^(64+L) / rbx). Since 2^(L-1) < rbx < 2^L, bit 64 of mlo
    ; is 1; a two-step long division yields its low 64 bits.
            xor     edx,edx                 ;rdx:rax = 2^L (high word of 2^(64+L))
            mov     rax,rsi
            div     rbx                     ;rax = 1, rdx = 2^L - rbx
            xor     eax,eax                 ;low word of 2^(64+L) is 0
            div     rbx
            mov     rdi,rax                 ;rdi = mlo mod 2^64
    ; mhi = floor((2^(64+L) + 2^(L+rspre)) / rbx), likewise with bit 64 set.
            xor     edx,edx
            mov     rax,rsi
            div     rbx                     ;rax = 1
            mov     rax,[rbp+fr_pow]        ;low word of 2^(64+L) + 2^(L+rspre)
            div     rbx                     ;rax = mhi mod 2^64
            mov     rdx,rdi                 ;rdx = mlo mod 2^64
            mov     ebx,8*sw                ;rbx = e: implicit top bit of mlo/mhi is 2^e

    ; While floor(mlo/2) < floor(mhi/2) and rspost > 0: halve both, e -= 1,
    ; rspost -= 1. Only the bits below 2^e are kept in rdx/rax; the compare
    ; is unaffected because the 2^e term is common to mlo and mhi.
    dcm00:  mov     rsi,rdx                 ;rsi = mlo/2
            shr     rsi,1
            mov     rdi,rax                 ;rdi = mhi/2
            shr     rdi,1
            cmp     rsi,rdi                 ;stop once mlo/2 >= mhi/2
            jae     dcm01
            mov     rdx,rsi                 ;mlo /= 2
            mov     rax,rdi                 ;mhi /= 2
            dec     rbx                     ;e -= 1
            dec     rcx                     ;rspost -= 1, repeat while nonzero
            jnz     dcm00
    dcm01:  mov     [rbp+fr_post],rcx       ;[fr_post] = rspost
            cmp     rbx,8*sw                ;no halving done: 65-bit multiplier
            je      dcm02
            xor     edx,edx                 ;rdx = 0: m fits in 64 bits
            mov     edi,1                   ;rax = m = mhi + 2^e
            mov     rcx,rbx
            shl     rdi,cl
            or      rax,rdi
            jmp     dct00

    dcm02:  mov     edx,1                   ;rdx = 1: m = 2^64 + rax
            dec     qword ptr [rbp+fr_post] ;65-bit sequence shifts by rspost-1
            jmp     dct00

    dcm03:  mov     rcx,[rbp+fr_pre]        ;power of 2: q = n >> rspre
            jmp     dct10

    ; Test multiplier m = rdx:rax (rdx holds bit 64) for every n from
    ; 0fffffffff0000000h until n wraps to 0. rdi:rsi keeps m across the loop.
    dct00:  mov     rdi,rdx                 ;rdi:rsi = m
            mov     rsi,rax
            mov     rbx,0fffffffff0000000h  ;rbx = first n
    dct01:  mov     [rbp+fr_n],rbx          ;[fr_n] = n
            mov     rdx,rdi                 ;rdx:rax = m
            mov     rax,rsi
            mov     rcx,[rbp+fr_pre]        ;rbx = n >> rspre
            shr     rbx,cl
            test    rdx,rdx                 ;65-bit m?
            jnz     dct02
            mul     rbx                     ;rdx = (m * (n >> rspre)) >> 64
            jmp     dct03

    ; 65-bit m (only arises with rspre = 0, so rbx = n):
    ; t = ((m - 2^64) * n) >> 64, then rdx = t + ((n - t) >> 1).
    dct02:  mul     rbx                     ;rdx = t
            sub     rbx,rdx                 ;rbx = (n - t) >> 1
            shr     rbx,1
            add     rdx,rbx
    dct03:  mov     rcx,[rbp+fr_post]       ;rdx = q = rdx >> rspost
            shr     rdx,cl
            mov     [rbp+fr_q],rdx          ;[fr_q] = q
            xor     edx,edx                 ;reference: rax = n / d via DIV
            mov     rax,[rbp+fr_n]
            mov     rbx,[rbp+fr_d]
            div     rbx
            cmp     rax,[rbp+fr_q]          ;quotients agree?
            je      dct04
            mov     r8d,1                   ;record the mismatch (breakpoint here)
    dct04:  mov     rbx,[rbp+fr_n]          ;next n; stop after n wraps to 0
            inc     rbx
            jnz     dct01
            jmp     main1

    ; Power-of-2 divisor: check q = n >> rspre (rcx) over the same range of n.
    dct10:  mov     rbx,0fffffffff0000000h  ;rbx = first n
    dct11:  mov     rsi,rbx                 ;rsi = n >> rspre
            shr     rsi,cl
            xor     edx,edx                 ;reference: rax = n / d via DIV
            mov     rax,rbx
            mov     rdi,[rbp+fr_d]
            div     rdi
            cmp     rax,rsi                 ;quotients agree?
            je      dct12
            mov     r8d,1                   ;record the mismatch (breakpoint here)
    dct12:  inc     rbx                     ;next n; stop after n wraps to 0
            jnz     dct11

    main1:  mov     rsi,[rbp+fr_ptr]        ;rbx = next divisor d
            mov     rbx,[rsi]
            add     rsi,sw                  ;advance the table pointer
            mov     [rbp+fr_ptr],rsi
            test    rbx,rbx                 ;a 0 entry ends the table
            jnz     main0

            mov     eax,r8d                 ;exit status: 0 = every quotient matched
            add     rsp,fr_size             ;release locals, restore registers
            pop     rsi
            pop     rdi
            pop     rbx
            pop     rbp
            ret     0

    main    ENDP
            END
    

提交回复
热议问题