Question
I have some critical branching code inside a loop that's run about 2^26 times. Branch prediction is not optimal because m is random. How would I remove the branching, possibly using bitwise operators?
bool m;
unsigned int a;
const unsigned int k = ...; // k >= 7
if(a == 0)
    a = (m ? (a+1) : (k));
else if(a == k)
    a = (m ? 0 : (a-1));
else
    a = (m ? (a+1) : (a-1));
And here is the relevant assembly generated by gcc -O3:
.cfi_startproc
movl 4(%esp), %edx
movb 8(%esp), %cl
movl (%edx), %eax
testl %eax, %eax
jne L15
cmpb $1, %cl
sbbl %eax, %eax
andl $638, %eax
incl %eax
movl %eax, (%edx)
ret
L15:
cmpl $639, %eax
je L23
testb %cl, %cl
jne L24
decl %eax
movl %eax, (%edx)
ret
L23:
cmpb $1, %cl
sbbl %eax, %eax
andl $638, %eax
movl %eax, (%edx)
ret
L24:
incl %eax
movl %eax, (%edx)
ret
.cfi_endproc
Answer 1:
The branch-free division-free modulo could have been useful, but testing shows that in practice, it isn't.
const unsigned int k = 639;
void f(bool m, unsigned int &a)
{
    a += m * 2 - 1;
    if (a == -1u)
        a = k;
    else if (a == k + 1)
        a = 0;
}
Testcase:
unsigned a = 0;
f(false, a);
assert(a == 639);
f(false, a);
assert(a == 638);
f(true, a);
assert(a == 639);
f(true, a);
assert(a == 0);
f(true, a);
assert(a == 1);
f(false, a);
assert(a == 0);
Actually timing this, using a test program:
int main()
{
    for (int i = 0; i != 10000; i++)
    {
        unsigned int a = k / 2;
        while (a != 0) f(rand() & 1, a);
    }
}
(Note: there's no srand, so results are deterministic.)
My original answer: 5.3s
The code in the question: 4.8s
Lookup table: 4.5s (static unsigned lookup[2][k+1];)
Lookup table: 4.3s (static unsigned lookup[k+1][2];)
Eric's answer: 4.2s
This version: 4.0s
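For reference, the "branch-free division-free modulo" idea mentioned at the start of this answer could look roughly like the sketch below. This is an illustration under assumptions, not necessarily what the original answer contained: the name f_branchless is hypothetical, and whether a compiler actually emits branch-free machine code for it is not guaranteed.

```cpp
#include <cassert>

const unsigned int k = 639; // same value as in the answer above

// Hypothetical helper: step a by +1 (m true) or -1 (m false), wrapping
// within [0, k], written without if/else, division, or modulo.
void f_branchless(bool m, unsigned int &a)
{
    a += m * 2u - 1u;                             // adds 1, or subtracts 1 in unsigned arithmetic
    const unsigned int under = 0u - (a == -1u);   // all ones if a wrapped below 0
    const unsigned int over  = 0u - (a == k + 1); // all ones if a stepped past k
    a = (a & ~(under | over)) | (k & under);      // under -> k, over -> 0, else unchanged
}

// Same walk as the test case in the answer above.
void check_f_branchless()
{
    unsigned int a = 0;
    f_branchless(false, a); assert(a == 639);
    f_branchless(false, a); assert(a == 638);
    f_branchless(true, a);  assert(a == 639);
    f_branchless(true, a);  assert(a == 0);
    f_branchless(true, a);  assert(a == 1);
    f_branchless(false, a); assert(a == 0);
}
```

The masks replace the two rare comparisons with arithmetic, which costs a few extra instructions on every call. That is consistent with the timings above, where the version with two well-predicted branches comes out ahead.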
Answer 2:
The fastest I've found is now the table implementation.
Timings I got (UPDATED for new measurement code):
HVD's most recent: 9.2s
Table version: 7.4s (with k=693)
Table creation code:
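unsigned int table[2*(k+1)]; // a takes the k+1 values 0..k, two entries each
table_ptr = table;
for(unsigned int i = 0; i <= k; i++){
    unsigned int a = i;
    f(0, a);
    table[i<<1] = a;
    a = i;
    f(1, a);
    table[(i<<1) + 1] = a; // parentheses needed: + binds tighter than <<
}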
Table runtime loop:
void f(bool m, unsigned int &a){
    a = table_ptr[a<<1 | m];
}
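A self-contained sketch of the table approach, sized for all k+1 values of a, might look like the following. The names step, build_table, f_table and check_table are hypothetical, and k = 639 is an assumption taken from the question's assembly rather than the value used for the timing above.

```cpp
#include <cassert>

const unsigned int k = 639; // assumed value for illustration

// One entry per (a, m) pair with a in [0, k], so 2 * (k + 1) entries.
static unsigned int table[2 * (k + 1)];

// Reference step, equivalent to the code in the question.
unsigned int step(bool m, unsigned int a)
{
    if (m) return a == k ? 0 : a + 1;
    return a == 0 ? k : a - 1;
}

// Fill the table once, before the hot loop.
void build_table()
{
    for (unsigned int i = 0; i <= k; i++) {
        table[(i << 1) | 0] = step(false, i);
        table[(i << 1) | 1] = step(true, i);
    }
}

// Hot-loop version: a single indexed load, no comparison on a or m.
void f_table(bool m, unsigned int &a)
{
    a = table[(a << 1) | m];
}

// Checks every table entry against the reference step.
void check_table()
{
    build_table();
    for (unsigned int i = 0; i <= k; i++) {
        unsigned int down = i, up = i;
        f_table(false, down);
        f_table(true, up);
        assert(down == step(false, i));
        assert(up == step(true, i));
    }
}
```

Interleaving the two entries for each a keeps both possible successors next to each other in memory, so either outcome of m reads from the same cache line in most cases.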
With HVD's measurement code, I saw the cost of rand() dominating the runtime, so the runtime of a branchless version was in about the same range as these solutions. I changed the measurement code to the following (UPDATED to keep the random branch order, and to precompute the random values so that rand() etc. doesn't trash the cache):
int main(){
    unsigned int a = k / 2;
    int m[100000];
    for(int i = 0; i < 100000; i++){
        m[i] = rand() & 1;
    }
    for (int i = 0; i != 10000; i++)
    {
        for(int j = 0; j != 100000; j++){
            f(m[j], a);
        }
    }
}
Answer 3:
I don't think you can remove the branches entirely, but you can reduce the number by branching on m first.
if (m){
    if (a==k) {a = 0;} else {++a;}
}
else {
    if (a==0) {a = k;} else {--a;}
}
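To confirm that branching on m first behaves the same as the three-way code in the question, an exhaustive comparison over every a in [0, k] is cheap. The sketch below assumes k = 639 (the constant visible in the question's assembly); the function names are hypothetical.

```cpp
#include <cassert>

const unsigned int k = 639; // assumed value, taken from the assembly in the question

// The three-way logic from the question.
unsigned int step_question(bool m, unsigned int a)
{
    if (a == 0)
        a = (m ? (a + 1) : k);
    else if (a == k)
        a = (m ? 0 : (a - 1));
    else
        a = (m ? (a + 1) : (a - 1));
    return a;
}

// The rewrite from this answer, branching on m first.
unsigned int step_m_first(bool m, unsigned int a)
{
    if (m) {
        if (a == k) { a = 0; } else { ++a; }
    } else {
        if (a == 0) { a = k; } else { --a; }
    }
    return a;
}

// Asserts that both versions agree for every a in [0, k] and both values of m.
void check_versions_agree()
{
    for (unsigned int a = 0; a <= k; a++) {
        assert(step_question(false, a) == step_m_first(false, a));
        assert(step_question(true, a) == step_m_first(true, a));
    }
}
```

In this form only the outer test on m is hard to predict; each inner test is true for just one value of a out of k+1.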
Answer 4:
Adding to Antimony's rewrite:
if (a==k) {a = 0;} else {++a;}
looks like an increase with wraparound. You can write this as
a=(a+1)%(k+1);
which, of course, only makes sense if divisions are actually faster than branches.
Not sure about the other one; too lazy to think about what the (~0)%(k+1) will be.
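Since a takes the k+1 values 0 through k, the modulus for the wraparound has to be k+1. For the decrement, (a-1) % (k+1) with a == 0 underflows to the largest unsigned value first, and that only reduces to k when k+1 is a power of two. Adding k instead of subtracting 1 sidesteps the underflow. A minimal sketch, with hypothetical helper names and an assumed k:

```cpp
#include <cassert>

const unsigned int k = 639; // assumed value for illustration

// Increment with wraparound: stepping past k lands on 0.
unsigned int inc_wrap(unsigned int a) { return (a + 1) % (k + 1); }

// Decrement with wraparound: adding k equals subtracting 1 modulo k + 1,
// and it never underflows for a in [0, k] (as long as 2 * k fits in unsigned int).
unsigned int dec_wrap(unsigned int a) { return (a + k) % (k + 1); }

// Spot checks at both edges and in the middle of the range.
void check_wrap()
{
    assert(inc_wrap(0) == 1);
    assert(inc_wrap(k) == 0);
    assert(dec_wrap(0) == k);
    assert(dec_wrap(k) == k - 1);
    assert(dec_wrap(inc_wrap(100)) == 100);
}
```

Because k is a compile-time constant, a compiler may replace the % with a multiply-and-shift sequence, but that is still more work per step than a well-predicted branch.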
Answer 5:
This has no branches. Because k is constant, the compiler might be able to optimize the modulo depending on its value. And if k is 'small', then a full lookup table solution would probably be even faster.
bool m;
unsigned int a;
const unsigned int k = ...; // k >= 7
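const unsigned int inc[2] = {k, 1}; // m == false: +k, i.e. -1 modulo k+1; m == true: +1
a = (a + inc[m]) % (k+1); // parentheses needed: % binds tighter than +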
Answer 6:
If k isn't large enough to cause overflow, you could do something like this:
int a; // Note: not unsigned int
int plusMinus = 2 * m - 1;
a += plusMinus;
if(a == -1)
    a = k;
else if (a == k+1)
    a = 0;
Still branches, but the branch prediction should be better, since the edge conditions are rarer than m-related conditions.
Source: https://stackoverflow.com/questions/12030022/branching-elimination-using-bitwise-operators