Combining two lists by key using Thrust

广开言路 2020-12-04 00:25

Given two key-value lists, I am trying to combine the two sides by matching the keys and applying a function to the two values when the keys match. In my case I want to mult

3 answers
  •  星月不相逢
    2020-12-04 00:53

    You can actually do everything you want with a single thrust::set_intersection_by_key call. However, some prerequisites need to be met:

    First, the easy one:

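    You need to zip Lvalsv and Rvalsv into a single thrust::zip_iterator and pass this as the values to thrust::set_intersection_by_key. Note that thrust::set_intersection_by_key accepts values only for the first key range, so each zipped tuple is read at the position of the matching key in Lkeysv.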

    You could already run this:

    // Element types are assumed to be int here; use the types of your own vectors.
    std::size_t min_size = std::min(Lsize, Rsize);
    thrust::device_vector<int> result_keys(min_size);
    thrust::device_vector<int> result_values_left(min_size);
    thrust::device_vector<int> result_values_right(min_size);
    
    // Pair up the left and right values so that both travel through one call.
    auto zipped_input_values = thrust::make_zip_iterator(thrust::make_tuple(Lvalsv.begin(), Rvalsv.begin()));
    auto zipped_output_values = thrust::make_zip_iterator(thrust::make_tuple(result_values_left.begin(), result_values_right.begin()));
    
    // result_pair holds the end iterators of the key output and of the value output.
    auto result_pair = thrust::set_intersection_by_key(Lkeysv.begin(), Lkeysv.end(), Rkeysv.begin(), Rkeysv.end(), zipped_input_values, result_keys.begin(), zipped_output_values);
    

    This would yield two result vectors, which you would need to multiply element-wise to get your final result.
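
    As a rough illustration of that element-wise step, here is a minimal host-side sketch in plain C++. The function name multiply_elementwise and the int element type are assumptions made for this example, not part of the original code. On the device, thrust::transform with thrust::multiplies takes its arguments in the same order.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Multiply two equally sized value vectors element by element.
// Illustrative helper: the name and the int element type are assumptions.
std::vector<int> multiply_elementwise(const std::vector<int>& left,
                                      const std::vector<int>& right) {
    assert(left.size() == right.size());
    std::vector<int> products(left.size());
    // Binary std::transform: products[i] = left[i] * right[i]
    std::transform(left.begin(), left.end(), right.begin(), products.begin(),
                   std::multiplies<int>());
    return products;
}
```

    For instance, calling multiply_elementwise({3, 1, 2, 1}, {2, 1, 4, 1}) returns {6, 1, 8, 1}. The cost is the one described next: two intermediate vectors are written, read back, and a third one is stored.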

    But wait, wouldn't it be great to avoid storing these two vectors, reading every element back to multiply them, and then storing the final result in a third vector?

    Let's do that. The concept I adapted is from here. The transform_output_iterator is an iterator that wraps another OutputIterator. When a value is written to the transform_output_iterator, a UnaryFunction is applied to it first, and the result is then written to the wrapped OutputIterator.

    This allows us to pass the output of thrust::set_intersection_by_key through the Multiplier functor and store the products in a single result_values vector.
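
    To make the write-through behaviour concrete, here is a small host-side sketch in plain C++. It is not the Thrust class: TransformingOutput, intersect_by_key and multiply_matching are hypothetical names invented for this illustration, the element type is assumed to be int, and a simple two-pointer loop over sorted, duplicate-free keys stands in for the Thrust call. It only shows how a value pair written to the wrapper ends up as a single product in one output vector.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <utility>
#include <vector>

// Hypothetical functor: multiplies the two halves of a value pair.
struct Multiplier {
    int operator()(const std::pair<int, int>& p) const { return p.first * p.second; }
};

// Illustrative output-iterator wrapper (not Thrust's implementation):
// a value assigned through it is passed to fn first, and the result is
// written to the wrapped iterator.
template <typename OutIt, typename Fn>
class TransformingOutput {
public:
    TransformingOutput(OutIt out, Fn fn) : out_(out), fn_(fn) {}

    // Dereferencing yields the wrapper itself, so "*it = v" reaches operator=.
    TransformingOutput& operator*() { return *this; }
    TransformingOutput& operator++() { ++out_; return *this; }

    template <typename T>
    TransformingOutput& operator=(const T& value) {
        *out_ = fn_(value);  // transform, then store in the wrapped iterator
        return *this;
    }

private:
    OutIt out_;
    Fn fn_;
};

// Two-pointer intersection over sorted, duplicate-free key lists. For each key
// present on both sides, the key is appended to keys_out and the pair of
// values is written through vals_out.
template <typename ValOut>
void intersect_by_key(const std::vector<int>& lkeys, const std::vector<int>& lvals,
                      const std::vector<int>& rkeys, const std::vector<int>& rvals,
                      std::vector<int>& keys_out, ValOut vals_out) {
    assert(lkeys.size() == lvals.size() && rkeys.size() == rvals.size());
    std::size_t i = 0, j = 0;
    while (i < lkeys.size() && j < rkeys.size()) {
        if (lkeys[i] < rkeys[j]) {
            ++i;
        } else if (rkeys[j] < lkeys[i]) {
            ++j;
        } else {
            keys_out.push_back(lkeys[i]);
            *vals_out = std::make_pair(lvals[i], rvals[j]);
            ++vals_out;
            ++i;
            ++j;
        }
    }
}

// Driver: returns the products for the matching keys; no intermediate
// left/right value vectors are stored.
std::vector<int> multiply_matching(const std::vector<int>& lkeys, const std::vector<int>& lvals,
                                   const std::vector<int>& rkeys, const std::vector<int>& rvals,
                                   std::vector<int>& keys_out) {
    std::vector<int> products;
    TransformingOutput<std::back_insert_iterator<std::vector<int>>, Multiplier>
        out(std::back_inserter(products), Multiplier());
    intersect_by_key(lkeys, lvals, rkeys, rvals, keys_out, out);
    return products;
}
```

    With made-up sample data, left keys {1, 2, 4, 5, 6} with values {3, 4, 1, 2, 1} and right keys {1, 3, 4, 5, 6, 7} with values {2, 1, 1, 4, 1, 2}, multiply_matching fills keys_out with {1, 4, 5, 6} and returns {6, 1, 8, 1}.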

    The following code implements this idea:

    #include 
    #include 
    #include 
    
    #include 
    #include 
    #include 
    #include 
    #include   
    #include 
    #include 
    
    #define PRINTER(name) print(#name, (name))
    template