I want to write a simple scan over an array. I have a std::vector<int> data and I want to find all array indices at which the elements are less than 9 and add them to a result vector. I can write this using a branch:
for (int i = 0; i < data.size(); ++i)
    if (data[i] < 9)
        r.push_back(i);
This gives the correct answer but I would like to compare it to a branchless version.
Using raw arrays (and assuming that data is an int array, length is the number of elements in it, and r is a result array with plenty of room), I can write something like:
int current_write_point = 0;
for (int i = 0; i < length; ++i) {
    r[current_write_point] = i;
    current_write_point += (data[i] < 9);
}
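The raw-array loop above can be wrapped in a small function to try it out. This is a sketch: the name scan_raw and its signature are made up for illustration. It relies on the fact that the write position never runs ahead of i, so a result buffer of length elements is always enough, even though a store happens on every iteration.

```cpp
#include <cstddef>

// Hypothetical wrapper around the branchless loop from the question.
// Writes the indices i with data[i] < 9 into r and returns how many were found.
// Assumes r has room for at least `length` ints: every iteration stores into
// r[current_write_point], even when the element does not qualify.
std::size_t scan_raw(const int* data, std::size_t length, int* r) {
    std::size_t current_write_point = 0;
    for (std::size_t i = 0; i < length; ++i) {
        r[current_write_point] = static_cast<int>(i);  // unconditional store
        current_write_point += (data[i] < 9);           // advance only on a match
    }
    return current_write_point;
}
```

A non-matching element simply gets its index overwritten by the next iteration, which is what makes the unconditional store safe.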
How would I get similar behavior using a vector for data?
Let's look at the actual compiler output for both versions:
auto scan_branch(const std::vector<int>& v)
{
    std::vector<int> res;
    for(int i = 0; i < v.size(); ++i)
    {
        if (v[i] < 9)
        {
            res.push_back(i);
        }
    }
    return res;
}
The disassembly of this code clearly shows a branch (line 26 of the listing). If the element is greater than or equal to 9, it just continues with the next element; if it is less than 9, a considerable amount of code runs for the push_back before the loop continues. Nothing unexpected.
auto scan_nobranch(const std::vector<int>& v)
{
    std::vector<int> res;
    res.resize(v.size());
    int insert_index = 0;
    for(int i = 0; i < v.size(); ++i)
    {
        res[insert_index] = i;
        insert_index += v[i] < 9;
    }
    res.resize(insert_index);
    return res;
}
This one, however, only has a conditional move (line 190 of the disassembly), so it looks like we have a winner. A conditional move does not depend on branch prediction, so it cannot cause a misprediction stall; apart from the loop condition check, there are no branches in this version.
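To check that the two versions actually return the same indices, here is a small self-contained sketch. The two scans are restated with std::size_t loop counters; the helper scans_agree and the choice of random values in [0, 18] are assumptions for illustration. Random input is also where the branching version should suffer most, since the predictor cannot learn a pattern, so the same generator is a reasonable starting point for a benchmark.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Branching version: push_back only when the element qualifies.
std::vector<int> scan_branch(const std::vector<int>& v) {
    std::vector<int> res;
    for (std::size_t i = 0; i < v.size(); ++i)
        if (v[i] < 9)
            res.push_back(static_cast<int>(i));
    return res;
}

// Branchless version: always store, conditionally advance, shrink at the end.
std::vector<int> scan_nobranch(const std::vector<int>& v) {
    std::vector<int> res(v.size());
    std::size_t insert_index = 0;
    for (std::size_t i = 0; i < v.size(); ++i) {
        res[insert_index] = static_cast<int>(i);
        insert_index += (v[i] < 9);
    }
    res.resize(insert_index);
    return res;
}

// Hypothetical sanity check: fills a vector with random values in [0, 18]
// (about half of them below 9) and reports whether both scans agree.
bool scans_agree(unsigned seed, std::size_t n) {
    std::mt19937 gen(seed);
    std::uniform_int_distribution<int> dist(0, 18);
    std::vector<int> v(n);
    for (int& x : v) x = dist(gen);
    return scan_branch(v) == scan_nobranch(v);
}
```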
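std::copy_if can do the filtering in one line, given a predicate. Note, though, that it copies the matching values rather than their indices, and the conditional push_back through back_inserter is still a branch:
std::copy_if(std::begin(data), std::end(data), std::back_inserter(r), [](int x) { return x < 9; });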
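Since std::copy_if filters values, one way to collect indices with it is to filter a range of indices 0..n-1 and look up data in the predicate. This is a sketch; the function name indices_below_9 is made up for illustration, and it is not branchless either, because copy_if still decides per element whether to write.

```cpp
#include <algorithm>
#include <iterator>
#include <numeric>
#include <vector>

// Sketch: build the index range 0..n-1, then keep the indices whose element
// is less than 9. back_inserter calls push_back for each kept index.
std::vector<int> indices_below_9(const std::vector<int>& data) {
    std::vector<int> idx(data.size());
    std::iota(idx.begin(), idx.end(), 0);  // 0, 1, 2, ..., n-1
    std::vector<int> r;
    std::copy_if(idx.begin(), idx.end(), std::back_inserter(r),
                 [&data](int i) { return data[i] < 9; });
    return r;
}
```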
Well, you could just resize the vector beforehand and keep your algorithm:
// Resize the vector so you can index it normally
r.resize(data.size());
// Do your algorithm like before
int current_write_point = 0;
for (int i = 0; i < static_cast<int>(data.size()); ++i) {
    r[current_write_point] = i;
    current_write_point += (data[i] < 9);
}
// Afterwards, current_write_point is the number of matches, so use it to
// shrink the vector and drop the trailing slots that were never kept
r.resize(current_write_point);
If you want to avoid comparisons altogether, though, you can use some bitwise and boolean operations with short-circuiting to determine that.
First, we know that all negative integers are less than 9. Second, if the number is non-negative, we can use a bitmask to determine whether it is in the range 0-15 (actually, we check whether it is NOT in that range, i.e. greater than 15). Finally, if subtracting 9 from the number gives a negative result, the number is less than 9:
Actually, I just figured out a better way. Since we can easily determine whether x < 0, we can just subtract 9 from x to determine whether x < 9 (as long as x - 9 does not overflow):
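#include <cstddef>
#include <vector>

// Use bitwise operations to determine if x is negative: returns the sign bit
// as 0 or 1 (assumes 32-bit int; the unsigned cast avoids shifting into the sign bit)
int n(int x) {
    return static_cast<unsigned>(x) >> 31;
}

int main() {
    std::vector<int> data = {3, 12, 9, 8, -4};  // example input
    std::vector<int> r(data.size());
    int current_write_point = 0;
    for (std::size_t i = 0; i < data.size(); ++i) {
        r[current_write_point] = static_cast<int>(i);
        // data[i] - 9 is negative exactly when data[i] < 9 (barring overflow)
        current_write_point += n(data[i] - 9);
    }
    r.resize(current_write_point);
}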
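The sign-bit test above is only correct while data[i] - 9 does not overflow, which fails for values within 9 of INT_MIN (and signed overflow is undefined behavior in C++). A sketch that avoids this, reusing the earlier idea of treating negative numbers separately, does the subtraction in unsigned arithmetic. The names less_than_9 and scan_bits are made up for illustration, a 32-bit int is assumed, and whether the compiler keeps the loop branch-free should still be checked in the generated assembly.

```cpp
#include <cstddef>
#include <vector>

// Returns 1 if x < 9 and 0 otherwise, using only shifts, OR and unsigned
// subtraction (no comparison operators). Assumes 32-bit int.
//  - (unsigned)x >> 31 is the sign bit of x: 1 for every negative x.
//  - For x >= 0, the unsigned difference x - 9u wraps around to a value with
//    the top bit set exactly when x < 9. Unsigned wrap-around is well defined,
//    so values near INT_MIN cannot trigger signed overflow.
int less_than_9(int x) {
    unsigned u = static_cast<unsigned>(x);
    return static_cast<int>((u >> 31) | ((u - 9u) >> 31));
}

// Same store-then-advance scan as above, using the overflow-safe predicate.
std::vector<int> scan_bits(const std::vector<int>& data) {
    std::vector<int> r(data.size());
    std::size_t w = 0;
    for (std::size_t i = 0; i < data.size(); ++i) {
        r[w] = static_cast<int>(i);
        w += less_than_9(data[i]);
    }
    r.resize(w);
    return r;
}
```

For x >= 16, x - 9 is positive anyway, so the 0-15 mask check from the first idea turns out to be unnecessary once negatives are handled by the sign bit.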
Source: https://stackoverflow.com/questions/38798841/how-do-i-write-a-branchless-stdvector-scan