问题
In my application I have a class like this:
// Structure-of-arrays edge container: index i of every vector below holds
// the data of the same edge i.
// NOTE(review): `class` members default to private and no access specifier
// is visible here, so a free function such as filter() cannot read these
// vectors unless they are made public or filter() is declared a friend.
class sample{
thrust::device_vector<int> edge_ID;    // identifier of each edge
thrust::device_vector<float> weight;   // weight of each edge
thrust::device_vector<int> layer_ID;   // layer each edge belongs to (the filter key)
/*functions, zip_iterators etc. */
};
At a given index every vector stores the corresponding data of the same edge.
I want to write a function that extracts all the edges belonging to a given layer, something like this:
// Pseudo-code for the intended behaviour (not compilable as written):
// for every index x of src, if that edge's layer_ID equals target_layer,
// copy the edge's values into dest.
void filter(const sample& src, sample& dest, const int& target_layer){
for(...){
if( src.layer_ID[x] == target_layer)/*copy values to dest*/;
}
}
The best way I've found to do this is with `thrust::copy_if(...)` (see the Thrust documentation for details).
It would look like this:
// Sketch of a thrust::copy_if-based filter. The problem being asked about:
// comparing_functor() is constructed with no arguments and its operator()
// receives only the element under test, so target_layer has no way in.
// NOTE(review): sample, as shown, declares no begin()/end(); presumably
// these would be zip_iterators over the three vectors — TODO confirm.
// NOTE(review): copy_if writes into dest and does not grow it, so dest
// must already be large enough; keep the returned iterator to learn how
// many elements were copied.
// NOTE(review): copy_if also has a stencil overload
// (first, last, stencil, result, pred), which would let layer_ID act as
// the stencil while the zipped edge data is copied.
void filter(const sample& src, sample& dest, const int& target_layer){
thrust::copy_if(src.begin(),
src.end(),
dest.begin(),
comparing_functor() );
}
And this is where we reach my problem:
The `comparing_functor()` is a unary function, which means I can't pass my `target_layer` value to it.
Does anyone know how to get around this, or have an idea for implementing it while keeping the class's data structure intact?
回答1:
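You can pass additional values to a functor for use in the predicate test, alongside the element that is ordinarily passed to it. Store the value as a data member initialized by the functor's constructor, and compare each element against that member in `operator()`. Here's a worked example: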
#include <iostream>
#include <iterator>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
// Demo parameters: number of input elements, and the value to filter for.
// Typed constants rather than macros, so they respect scope and type rules.
const int DSIZE = 10;
const int FVAL = 5;
struct test_functor
{
const int a;
test_functor(int _a) : a(_a) {}
__device__
bool operator()(const int& x ) {
return (x==a);
}
};
int main(){
int target_layer = FVAL;
thrust::host_vector<int> h_vals(DSIZE);
thrust::sequence(h_vals.begin(), h_vals.end());
thrust::device_vector<int> d_vals = h_vals;
thrust::device_vector<int> d_result(DSIZE);
thrust::copy_if(d_vals.begin(), d_vals.end(), d_result.begin(), test_functor(target_layer));
thrust::host_vector<int> h_result = d_result;
std::cout << "Data :" << std::endl;
thrust::copy(h_vals.begin(), h_vals.end(), std::ostream_iterator<int>( std::cout, " "));
std::cout << std::endl;
std::cout << "Filter Value: " << target_layer << std::endl;
std::cout << "Results :" << std::endl;
thrust::copy(h_result.begin(), h_result.end(), std::ostream_iterator<int>( std::cout, " "));
std::cout << std::endl;
return 0;
}
来源:https://stackoverflow.com/questions/17468745/thrust-filter-by-key-value