Optimise the “binary_fold” algorithm and make it left (or right) associative

天大地大妈咪最大 提交于 2020-01-05 23:33:03

问题


Following my original question and considering some of the proposed solutions I came up with this for C++14:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <exception>
#include <iterator>
#include <stdexcept>
#include <type_traits>
#include <utility>

/// Folds the non-empty range [begin, end) with `op` by splitting it into a
/// left part with an even number of elements and a right remainder, folding
/// both parts recursively and combining the two partial results with `op`.
///
/// @param begin, end  random-access iterator range (`end - begin` is used).
/// @param op          binary operation applied to elements / partial results.
/// @return            the folded value; a single element is returned as-is.
/// @throws std::out_of_range if the range is empty.
template<class It, class Func>
auto binary_fold(It begin, It end, Func op) ->  decltype(op(*begin, *end)) {
  const std::ptrdiff_t diff = end - begin;
  switch (diff) {
    case 0: throw std::out_of_range("binary fold on empty container");
    case 1: return *begin;
    case 2: return op(*begin, *(begin + 1));
    default: {
      // Split after half of the elements, rounded up to an even count so the
      // left part always consists of whole pairs. (half + 1) & ~1 rounds an
      // odd half up and leaves an even half unchanged. Staying in
      // std::ptrdiff_t avoids narrowing large ranges into `int`.
      const std::ptrdiff_t half = diff / 2;
      const std::ptrdiff_t offset = (half + 1) & ~static_cast<std::ptrdiff_t>(1);
      It mid{begin};
      std::advance(mid, offset);
      return op( binary_fold(begin,mid,op), binary_fold(mid,end,op) );
    }
  }
}

This algorithm applies a binary operation to pairs of elements, recursively, until a single result is obtained. For example:

 std::vector<int> v = {1,3,5,6,1};
 auto result = mar::binary_fold(v.cbegin(), v.cend(), std::minus<int>());

will resolve in:

(1-3) - ((5-6) - 1) = 0

In some cases (like the one above) the grouping leans to the right, but in others (like the following) it leans to the left, so the algorithm is neither consistently left nor consistently right associative:

  std::vector<int> v = {7,4,9,2,6,8};
  auto result = mar::binary_fold(v.cbegin(), v.cend(), std::minus<int>());

results in:

((7-4) - (9-2)) - (6-8) = -2

I'm wondering how I can further optimise this algorithm so that:

a. it is definitely left or right associative

b. it is as fast as possible (this will be put within an openGL drawing loop, so it has to be very fast).

c. make a TMP version that computes the offsets at compile time when the size of the container is known (this is not necessary for my application, but I'm curious about how it can be done).

My first thought on (b) is that an iterative version would probably be faster, and that the offset calculation could be optimised further (maybe with some bitwise magic?). Nevertheless, I'm stuck.


回答1:


I have two TMP versions. Which one is better, depends on the data types, I guess:

Solution A:

First, let's find a good offset for the split point (powers of two seem great):

/// Compile-time split point for binary_fold_tmp.
///
/// Starting from Step = 2, Step is doubled while 2 * Step < Count - 1; `value`
/// is the Step at which the doubling stops, i.e. the smallest power of two
/// >= 2 with 2 * Step >= Count - 1 (the chain is capped at 2^16 below).
template<std::ptrdiff_t Count, std::ptrdiff_t Step = 2>
struct offset
{
  static constexpr std::ptrdiff_t value =
      (Count - 1 <= Step * 2) ? Step : offset<Count, Step * 2>::value;
};

// Terminates the doubling chain: once Step reaches 2^16 it is used as-is.
template<std::ptrdiff_t Count>
struct offset<Count, std::ptrdiff_t{1} << 16>
{
  static constexpr std::ptrdiff_t value = std::ptrdiff_t{1} << 16;
};

// Ranges of 0, 1 or 2 elements get a split offset of 0.
template<> struct offset<0, 2> { static constexpr std::ptrdiff_t value = 0; };
template<> struct offset<1, 2> { static constexpr std::ptrdiff_t value = 0; };
template<> struct offset<2, 2> { static constexpr std::ptrdiff_t value = 0; };

With this, we can create a recursive TMP version:

/// Folds exactly `diff` elements starting at `begin` with `op`, where `diff`
/// is a compile-time constant. The range is split at offset<diff>::value and
/// both parts are folded recursively, so the bracketing is fixed at compile
/// time.
///
/// @tparam diff  number of elements in [begin, end); checked with assert.
/// @throws std::out_of_range if executed with diff == 0.
template <std::ptrdiff_t diff, class It, class Func>
auto binary_fold_tmp(It begin, It end, Func op)
    -> decltype(op(*begin, *end))
{
  assert(end - begin == diff);
  switch (diff)
  {
    case 0:
      // binary_fold_tmp<0> is instantiated through the (never executed)
      // default branch of the 1- and 2-element cases. Throwing instead of
      // `return 0;` keeps that instantiation valid for result types that
      // cannot be constructed from the literal 0.
      throw std::out_of_range("binary fold on empty container");
    case 1:
      return *begin;
    case 2:
      return op(*begin, *(begin + 1));
    default:
    {
      // Split at the compile-time offset computed by `offset<diff>`.
      It mid{begin};
      std::advance(mid, offset<diff>::value);
      auto left = binary_fold_tmp<offset<diff>::value>(begin, mid, op);
      auto right =
          binary_fold_tmp<diff - offset<diff>::value>(mid, end, op);
      return op(left, right);
    }
  }
}

This can be combined with a non-TMP version like this, for instance:

/// Runtime entry point for the compile-time folds. Ranges of 1..8 elements
/// are dispatched to the matching binary_fold_tmp<N>; longer ranges peel off
/// an 8-, 16- or 32-element prefix, fold it with binary_fold_tmp, and fold the
/// (non-empty) remainder recursively.
///
/// @throws std::out_of_range if the range is empty.
template <class It, class Func>
auto binary_fold(It begin, It end, Func op)
    -> decltype(op(*begin, *end))
{
  const auto diff = end - begin;
  if (diff <= 0)
    throw std::out_of_range("binary fold on empty container");
  switch (diff)
  {
    case 1:
      return binary_fold_tmp<1>(begin, end, op);
    case 2:
      return binary_fold_tmp<2>(begin, end, op);
    case 3:
      return binary_fold_tmp<3>(begin, end, op);
    case 4:
      return binary_fold_tmp<4>(begin, end, op);
    case 5:
      return binary_fold_tmp<5>(begin, end, op);
    case 6:
      return binary_fold_tmp<6>(begin, end, op);
    case 7:
      return binary_fold_tmp<7>(begin, end, op);
    case 8:
      return binary_fold_tmp<8>(begin, end, op);
    default:
      // diff >= 9 here. The inclusive bounds guarantee that every recursive
      // remainder holds at least one element (exactly 16 or 32 elements must
      // not leave an empty tail behind).
      if (diff <= 16)
        return op(binary_fold_tmp<8>(begin, begin + 8, op),
                  binary_fold(begin + 8, end, op));
      else if (diff <= 32)
        return op(binary_fold_tmp<16>(begin, begin + 16, op),
                  binary_fold(begin + 16, end, op));
      else
        return op(binary_fold_tmp<32>(begin, begin + 32, op),
                  binary_fold(begin + 32, end, op));
  }
}

Solution B:

This calculates the pair-wise results, stores them in a buffer, and then calls itself with the buffer:

/// One round of the pairwise fold: combines the `diff` elements starting at
/// `begin` into diff/2 pair results held in a local array, then folds that
/// array again until a single value remains. `diff` must be a power of two.
template <std::ptrdiff_t diff, class It, class Func, std::size_t... Is>
auto binary_fold_pairs_impl(It begin,
                            It end,
                            Func op,
                            const std::index_sequence<Is...>&)
    -> decltype(op(*begin, *end))
{
  static_assert(diff >= 2 && (diff & (diff - 1)) == 0,
                "binary_fold_pairs requires a power-of-two element count");
  // Store pair results with the operation's own (decayed) result type, so an
  // `op` that widens its inputs is not truncated back to the element type.
  using pair_type = std::decay_t<decltype(op(*begin, *begin))>;
  pair_type pairs[diff / 2] = {
      op(*(begin + 2 * Is), *(begin + 2 * Is + 1))...};

  if (diff == 2)
    return pairs[0];
  else
    // The clamped template arguments stop the (never taken) diff == 2 path
    // from instantiating a zero-length array, which is ill-formed; for
    // diff == 2 they name a 2-element fold over pair_type*, which ends the
    // instantiation chain.
    return binary_fold_pairs_impl<(diff > 2 ? diff / 2 : 2)>(
        &pairs[0],
        &pairs[0] + diff / 2,
        op,
        std::make_index_sequence<(diff > 2 ? diff / 4 : 1)>{});
}

/// Folds exactly `diff` elements starting at `begin` by repeatedly combining
/// adjacent pairs: ((e0 op e1) op (e2 op e3)) op ...
/// @tparam diff  element count; must be a power of two >= 2 (static_assert).
/// `end` is not read at runtime; the element count comes from `diff`.
template <std::ptrdiff_t diff, class It, class Func>
auto binary_fold_pairs(It begin, It end, Func op) -> decltype(op(*begin, *end))
{
  return binary_fold_pairs_impl<diff>(
      begin, end, op, std::make_index_sequence<diff / 2>{});
}

This template function requires diff to be a power of 2. But of course you can combine it with a non-template version, again:

/// Runtime dispatcher for binary_fold_pairs. Ranges of 1..8 elements are
/// split into power-of-two chunks folded with binary_fold_pairs (plus a
/// trailing element or a recursive remainder); longer ranges peel off an
/// 8-, 16- or 32-element prefix and fold the non-empty remainder recursively.
///
/// @throws std::out_of_range if the range is empty.
template <class It, class Func>
auto binary_fold_mix(It begin, It end, Func op) -> decltype(op(*begin, *end))
{
  const auto diff = end - begin;
  if (diff <= 0)
    throw std::out_of_range("binary fold on empty container");
  switch (diff)
  {
    case 1:
      return *begin;
    case 2:
      return binary_fold_pairs<2>(begin, end, op);
    case 3:
      return op(binary_fold_pairs<2>(begin, begin + 2, op),
                *(begin + 2));
    case 4:
      return binary_fold_pairs<4>(begin, end, op);
    case 5:
      return op(binary_fold_pairs<4>(begin, begin + 4, op),
                *(begin + 4));
    case 6:
      // The tail holds only two elements, so it is folded as a single pair;
      // a 4-element fold starting at begin + 4 would read past `end`.
      return op(binary_fold_pairs<4>(begin, begin + 4, op),
                binary_fold_pairs<2>(begin + 4, begin + 6, op));
    case 7:
      return op(binary_fold_pairs<4>(begin, begin + 4, op),
                binary_fold_mix(begin + 4, begin + 7, op));
    case 8:
      return binary_fold_pairs<8>(begin, end, op);
    default:
      // diff >= 9 here, so every recursive remainder below is non-empty.
      if (diff <= 16)
        return op(binary_fold_pairs<8>(begin, begin + 8, op),
                  binary_fold_mix(begin + 8, end, op));
      else if (diff <= 32)
        return op(binary_fold_pairs<16>(begin, begin + 16, op),
                  binary_fold_mix(begin + 16, end, op));
      else
        return op(binary_fold_pairs<32>(begin, begin + 32, op),
                  binary_fold_mix(begin + 32, end, op));
  }
}

I measured with the same program as MtRoad. On my machine, the differences are not as big as MtRoad reported. With -O3 solutions A and B seem to be slightly faster than MtRoad's version, but in reality, you need to test with your types and data.

Remark: I did not test my versions too rigorously.




回答2:


I wrote an "always-left associative" iterative version, with some timing runs you can use as well. It performs slightly worse until you turn on compiler optimizations.

Total run times for 10000 iterations, with 5000 values.

$ g++ --std=c++11 main.cpp && ./a.out
Recursive elapsed:9642msec
Iterative elapsed:10189msec

$ g++ --std=c++11 -O1 main.cpp && ./a.out
Recursive elapsed:3468msec
Iterative elapsed:3098msec
Iterative elapsed:3359msec # another run
Recursive elapsed:3668msec

$ g++ --std=c++11 -O2 main.cpp && ./a.out
Recursive elapsed:3193msec
Iterative elapsed:2763msec
Recursive elapsed:3184msec # another run
Iterative elapsed:2696msec

$ g++ --std=c++11 -O3 main.cpp && ./a.out
Recursive elapsed:3183msec
Iterative elapsed:2685msec
Recursive elapsed:3106msec # another run
Iterative elapsed:2681msec
Recursive elapsed:3054msec # another run
Iterative elapsed:2653msec

Compilers often have a slightly easier time optimizing loops than recursion.

#include <algorithm>
#include <functional>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>    

// Recursive fold of the non-empty range [begin, end) with `op`: the range is
// split into a left part with an even number of elements and a right
// remainder, both are folded recursively and combined with `op`.
// Throws std::out_of_range for an empty range; one element is returned as-is.
template<class It, class Func>
auto binary_fold_rec(It begin, It end, Func op) ->  decltype(op(*begin, *end)) {
  const std::ptrdiff_t diff = end - begin;
  switch (diff) {
    case 0: throw std::out_of_range("binary fold on empty container");
    case 1: return *begin;
    case 2: return op(*begin, *(begin + 1));
    default: {
      // Split after half of the elements, rounded up to an even count:
      // (half + 1) & ~1 rounds an odd half up and keeps an even half.
      // std::ptrdiff_t (not int) so large ranges are not narrowed.
      const std::ptrdiff_t half = diff / 2;
      const std::ptrdiff_t offset = (half + 1) & ~static_cast<std::ptrdiff_t>(1);
      It mid{begin};
      std::advance(mid, offset);
      return op( binary_fold_rec(begin,mid,op), binary_fold_rec(mid,end,op) );
    }
  }
}


// Iterative pairwise fold. Each round combines adjacent elements left to
// right and carries an odd trailing element into the next round unchanged,
// until one value remains. E.g. {a, b, c, d, e} with `-` gives
// ((a - b) - (c - d)) - e.
//
// begin, end: random-access range to fold (`end - begin` is used).
// Returns the folded value; a single element is returned as-is.
// Throws std::out_of_range if the range is empty.
template<class It, class Func>
auto binary_fold_it(It begin, It end, Func op) -> decltype(op(*begin, *end)) {
    using result_type = decltype(op(*begin, *end));

    const std::ptrdiff_t diff = end - begin;
    if (diff == 0) {
        throw std::out_of_range("binary fold on empty container.");
    }
    if (diff == 1) {
        // Nothing to combine; the input element is the result.
        return *begin;
    }

    // The first round reads the input and produces ceil(diff / 2) partial
    // results; later rounds run in place inside this single buffer. Keeping
    // the later rounds on the buffer (instead of reassigning begin/end)
    // makes the function work for any random-access iterator type.
    std::vector<result_type> scratch;
    scratch.reserve(static_cast<std::size_t>((diff + 1) / 2));
    for (; end - begin > 1; begin += 2) {
        scratch.push_back(op(*begin, *(begin + 1)));
    }
    if (begin != end) {
        scratch.push_back(*begin);
    }

    std::size_t count = scratch.size();
    while (count > 1) {
        std::size_t out = 0;
        std::size_t i = 0;
        for (; i + 1 < count; i += 2) {
            // op's result is computed before it is stored, and out <= i, so
            // no slot is overwritten before it has been read.
            scratch[out] = op(scratch[i], scratch[i + 1]);
            ++out;
        }
        if (i < count) {
            scratch[out] = scratch[i];
            ++out;
        }
        count = out;
    }
    return scratch[0];
}


// Debug helper: folds `elems` with binary_fold_it and std::minus<int>,
// prints the result and asserts that it equals `expected`.
void run(std::initializer_list<int> elems, int expected) {
    const std::vector<int> v(elems);
    const auto result = binary_fold_it(v.begin(), v.end(), std::minus<int>());
    std::cout << result << '\n';
    // Check the value already computed instead of folding the range again.
    assert(result == expected);
    static_cast<void>(expected);  // keeps NDEBUG builds free of unused warnings
}

// Benchmark parameters: number of timed repetitions, the value range of the
// random test data, and the number of values in each data vector.
constexpr int rolls{10000};
constexpr int min_val{-1000};
constexpr int max_val{1000};
constexpr int num_vals{5000};

// Returns num_vals ints drawn uniformly from [min_val, max_val].
// The engine is default-constructed on every call, so each call yields the
// same sequence (reproducible benchmark input).
std::vector<int> random_vector() {
    std::default_random_engine engine;
    std::uniform_int_distribution<int> pick(min_val, max_val);
    std::vector<int> values;
    values.reserve(num_vals);
    for (int k = 0; k < num_vals; ++k) {
        values.push_back(pick(engine));
    }
    return values;
}

// Times `rolls` invocations of `func` on freshly generated random data and
// prints `message` followed by the elapsed milliseconds.
// NOTE(review): main() does not use this overload, and a fold such as
// binary_fold_it returns a value, so it does not appear to match the
// void(*)(It, It, Func) parameter as written -- confirm whether this is dead code.
template<typename It, typename Func>
void evaluate(void(*func)(It, It, Func), const char* message) {
    using timer_clock = std::chrono::high_resolution_clock;
    const auto started = timer_clock::now();
    for (int round = 0; round < rolls; ++round) {
        auto input = random_vector();
        func(input.begin(), input.end(), std::minus<int>());
    }
    const auto stopped = timer_clock::now();
    const auto elapsed_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(stopped - started);
    std::cout << message << elapsed_ms.count() << "msec\n";
}


// Calls `func` `rolls` times and prints `message` followed by the elapsed
// time in milliseconds.
void evaluate(void(*func)(), const char* message) {
    // steady_clock is monotonic, so the interval cannot be skewed by system
    // clock adjustments; high_resolution_clock gives no such guarantee.
    const auto start = std::chrono::steady_clock::now();
    for(int i=0; i<rolls; i++) {
        func();
    }
    const auto end = std::chrono::steady_clock::now();
    std::cout << message << std::chrono::duration_cast<std::chrono::milliseconds>(end-start).count() << "msec\n";
}


// Benchmark body for the iterative fold. random_vector() returns the same
// data on every call (default-seeded engine), so it is generated once here;
// that keeps num_vals RNG draws plus an allocation out of every timed call.
// The result goes to a volatile so the optimiser cannot discard the fold.
void time_it() {
    static const std::vector<int> data = random_vector();
    volatile int sink = binary_fold_it(data.begin(), data.end(), std::minus<int>());
    static_cast<void>(sink);
}


// Benchmark body for the recursive fold. random_vector() returns the same
// data on every call (default-seeded engine), so it is generated once here;
// that keeps num_vals RNG draws plus an allocation out of every timed call.
// The result goes to a volatile so the optimiser cannot discard the fold.
void time_rec() {
    static const std::vector<int> data = random_vector();
    volatile int sink = binary_fold_rec(data.begin(), data.end(), std::minus<int>());
    static_cast<void>(sink);
}


// Runs the recursive benchmark, then the iterative one, printing each time.
int main() {
    struct Benchmark {
        void (*body)();
        const char* label;
    };
    const Benchmark benchmarks[] = {
        {time_rec, "Recursive elapsed:"},
        {time_it, "Iterative elapsed:"},
    };
    for (const Benchmark& b : benchmarks) {
        evaluate(b.body, b.label);
    }
}


来源:https://stackoverflow.com/questions/35318908/optimise-binary-fold-algorithm-and-make-it-left-or-right-associative

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!