How do I write a branchless std::vector scan?

血红的双手。 Submitted on 2019-12-05 12:41:52

Let's look at the actual compiler output:

#include <vector>

// Collect the indices of all elements less than 9, using an explicit branch
auto scan_branch(const std::vector<int>& v)
{
  std::vector<int> res;
  for(int i = 0; i < v.size(); ++i)
  {
    if (v[i] < 9)
    {
       res.push_back(i);
    }
  }
  return res;
}

This code clearly has a branch at line 26 of the disassembly. If the element is greater than or equal to 9, it simply continues with the next element; if it is less than 9, a sizeable chunk of code runs for the push_back before the loop continues. Nothing unexpected.

auto scan_nobranch(const std::vector<int>& v)
{
  std::vector<int> res;
  res.resize(v.size());

  int insert_index = 0;
  for(int i = 0; i < v.size(); ++i)
  {
    res[insert_index] = i;
    insert_index += v[i] < 9;
  }

  res.resize(insert_index);
  return res;
}

This one, however, only has a conditional move, which you can see at line 190 of the disassembly. It looks like we have a winner. A conditional move cannot be mispredicted, so it does not cause the pipeline flush that a mispredicted branch does, and there are no branches in this one (except the for condition check).
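Before reading too much into the disassembly, it may be worth confirming that the two loop bodies really return the same indices. The following is a minimal sketch, not the code above verbatim: the function names and the `threshold` parameter are hypothetical, introduced only so the two variants can be compared side by side.

```cpp
#include <cstddef>
#include <vector>

// Branchy reference: push the index only when the element is below the threshold.
std::vector<int> indices_below_branchy(const std::vector<int>& v, int threshold)
{
    std::vector<int> res;
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        if (v[i] < threshold)
        {
            res.push_back(static_cast<int>(i));
        }
    }
    return res;
}

// Branchless body: always store the index, advance the write position only on a match.
std::vector<int> indices_below_branchless(const std::vector<int>& v, int threshold)
{
    std::vector<int> res(v.size());   // worst case: every element matches
    std::size_t out = 0;
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        res[out] = static_cast<int>(i);
        out += v[i] < threshold;      // bool converts to 0 or 1
    }
    res.resize(out);                  // drop the slots that were never kept
    return res;
}
```

Calling both with the same input and comparing the results with `==` should succeed. Note the trade-off: the branchless version allocates space for `v.size()` indices up front, even if only a few elements match.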

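// Note: this copies the matching values themselves, not their indices
std::copy_if(std::begin(data), std::end(data), std::back_inserter(r),
             [](int x) { return x < 9; });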
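If the goal is to collect indices rather than values, one possible adaptation of the std::copy_if idea is to run it over a sequence of indices. This is only a sketch: the function name and the `threshold` parameter are hypothetical, and nothing here guarantees branchless code, since the compiler is free to emit a branch for the predicate.

```cpp
#include <algorithm>
#include <iterator>
#include <numeric>
#include <vector>

// Collect the indices i for which data[i] < threshold, using std::copy_if.
std::vector<int> indices_below_copy_if(const std::vector<int>& data, int threshold)
{
    // Build the index sequence 0, 1, ..., data.size() - 1.
    std::vector<int> indices(data.size());
    std::iota(indices.begin(), indices.end(), 0);

    // Keep only the indices whose element passes the test.
    std::vector<int> r;
    std::copy_if(indices.begin(), indices.end(), std::back_inserter(r),
                 [&](int i) { return data[i] < threshold; });
    return r;
}
```

This reads clearly but costs an extra vector of indices, so it is probably a better fit when readability matters more than the last bit of speed.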

Well, you could just resize the vector beforehand and keep your algorithm:

// Resize the vector so you can index it normally
r.resize(length);

// Do your algorithm like before
int current_write_point = 0;
for (int i = 0; i < length; ++i){
    r[current_write_point] = i;
    current_write_point += (data[i] < 9);
}

// Afterwards, current_write_point is the number of elements kept, so use it
// to shrink the vector and drop the excess elements not written to
r.resize(current_write_point);

If you wanted no comparisons at all, though, you could use some bitwise and boolean operations with short-circuiting to determine that.

First, we know that all negative integers are less than 9. Second, if the integer is positive, we can use a bitmask to determine whether it is in the range 0-15 (actually, we'd check that it's NOT in that range, i.e. greater than 15). Then, if subtracting 9 from that number gives a negative result, the number is less than 9. Actually, I just figured out a better way. Since we can easily determine whether x < 0, we can just subtract 9 from x to determine whether x < 9:

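#include <iostream>
#include <vector>

// Use bitwise operations to determine if x is negative:
// returns 1 if the sign bit is set, 0 otherwise (assumes a 32-bit int)
int n(int x) {
    return static_cast<unsigned>(x) >> 31;
}

int main() {
    std::vector<int> data = {3, 12, 9, -4, 8};  // sample input
    int length = static_cast<int>(data.size());
    std::vector<int> r(length);

    int current_write_point = 0;
    for (int i = 0; i < length; ++i){
        r[current_write_point] = i;
        // data[i] - 9 is negative exactly when data[i] < 9
        // (the subtraction overflows for values near INT_MIN)
        current_write_point += n(data[i] - 9);
    }
    r.resize(current_write_point);

    for (int index : r) std::cout << index << ' ';
    std::cout << '\n';
}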