Pairwise operation on segmented data in CUDA/thrust

前端 未结 2 521
礼貌的吻别
礼貌的吻别 2021-01-03 14:36

Suppose I have

  • a data array,
  • an array containing keys referencing entries in the data array and
  • a third array which contains an id
2条回答
  •  半阙折子戏
    2021-01-03 15:12

    I've found this solution to your question. I used a mix of CUDA kernels and thrust primitives. I assume that your operator is commutative in its key arguments, that is

    fun(key1, key2, id)  == fun(key2, key1, id)
    

    My solution produces three output arrays (keys1, keys2 and ids) in two steps. In the first step it computes only the number of elements in the output arrays, and in the second step it fills the arrays. Basically the algorithm runs twice: the first time it "simulates" the writing, and the second time it actually writes the output. This is a common pattern that I use when the output size depends on the input data.

    Here are the steps:

    1. Sort the keys by segment. This step is the same as in the algorithm of @Robert Crovella.
    2. Count the number of pairs produced by each key. Each thread is associated with a key. Each thread counts all the valid pairs from its base index to the end of the segment. The output of this step is in the vector d_cpairs.
    3. Compute the size of keys1, keys2 and ids and the offset of each pair in those arrays. A sum-reduction on d_cpairs computes the size of the output arrays. An exclusive-scan operation, performed on d_cpairs, produces the positions of the pairs in the output arrays. The output of this step is in the vector d_addrs.
    4. Fill keys1, keys2 and ids with the data. This step is exactly the same as step 2 (the counting pass), but it actually writes the data to keys1, keys2 and ids.

    Here is the output of my algorithm:

        d_keys: 1 2 3 1 1 2 2 1 1 1 
        d_segs: 0 0 0 1 2 2 2 3 3 3 
        d_cpairs: 2 1 0 0 1 0 0 0 0 0 
        num_pairs: 4
        d_addrs: 0 2 3 3 3 4 4 4 4 4 
        keys1: 1 1 2 1 
        keys2: 2 3 3 2 
        ids: 0 0 0 2
    

    Note that the last column in the output is not exactly what you want, but if the commutative property holds it's fine.

    Here is my complete code. Probably this is not the fastest solution but I think it's simple.

    #include <iostream>
    #include <thrust/device_vector.h>
    #include <thrust/host_vector.h>
    #include <thrust/reduce.h>
    #include <thrust/scan.h>
    #include <thrust/sort.h>
    
    // Prints "<expression text of V>: " followed by the first `size` elements of V,
    // each followed by a space, then a newline.
    // The body is wrapped in do { } while(0) so the macro expands to exactly one
    // statement and stays correct under an unbraced if/else; arguments are
    // parenthesized so expressions such as `n + 1` expand safely.
    #define SHOW_VECTOR(V,size)  do { std::cout << #V << ": "; for(int show_i_ = 0; show_i_ < (size); show_i_++){ std::cout << (V)[show_i_] << " "; } std::cout << std::endl; } while(0)
    // Converts the pointer returned by V.data() into a raw pointer that can be
    // passed to a kernel.
    #define RAW_CAST(V) thrust::raw_pointer_cast((V).data())
    
    /**
     * Counting pass: for every element, counts how many pairs it forms with
     * later elements of the same segment.
     *
     * One thread handles one element (tidx). Starting at tidx+1 it walks forward
     * while the segment id stays equal to its own, and counts an element i when
     * its key differs from the thread's key and from the key immediately before
     * it (d_keys[i-1]); the second test skips adjacent repeats of the same key.
     *
     * d_keys   : keys, siz entries (read only)
     * d_segs   : segment id of each key, siz entries (read only); the walk stops
     *            at the first differing id, so equal ids must be contiguous
     * d_cpairs : output, siz entries; d_cpairs[tidx] receives the count
     * siz      : number of elements
     *
     * Launch with a 1D grid covering at least siz threads; extra threads do nothing.
     *
     * NOTE(review): an element whose key repeats an earlier key of the same
     * segment still counts its own pairs (e.g. keys 1 1 2 count (1,2) twice) -
     * confirm whether that is intended. scatter_kernel applies the same rule,
     * so the two passes stay in agreement.
     */
    __global__ 
    void count_kernel(const int *d_keys,const int *d_segs,int *d_cpairs, int siz){
        int tidx = threadIdx.x+ blockIdx.x*blockDim.x;
    
        if(tidx < siz){
            int sum = 0;
            // The i < siz test must come first: without it the threads of the
            // last segment would read d_segs[siz], one past the end of the array.
            for(int i = tidx+1; i < siz && d_segs[i] == d_segs[tidx]; i++){
                if(d_keys[i] != d_keys[tidx] &&
                   d_keys[i] != d_keys[i-1]){
                    sum++;
                }
            }
            d_cpairs[tidx] = sum;
        }
    }
    
    /**
     * Writing pass: repeats the walk of the counting pass, but stores each pair
     * instead of counting it.
     *
     * One thread handles one element (tidx). Its pairs are written to consecutive
     * slots starting at d_addrs[tidx]: the thread's own key goes to d_keys1, the
     * partner key to d_keys2 and the partner's segment id to d_ids.
     *
     * d_keys   : keys, siz entries (read only)
     * d_segs   : segment id of each key, siz entries (read only); the walk stops
     *            at the first differing id, so equal ids must be contiguous
     * d_addrs  : first output slot of each element, siz entries (read only)
     * d_keys1, d_keys2, d_ids : output arrays, indexed from d_addrs[tidx] upward
     * siz      : number of elements
     *
     * Launch with a 1D grid covering at least siz threads; extra threads do nothing.
     *
     * NOTE(review): presumably d_addrs is the exclusive scan of the per-element
     * counts produced with the same pair rule, and the outputs hold the total
     * count - verify against the caller. As in the counting rule, an element
     * whose key repeats an earlier key of the same segment still emits its own
     * pairs; confirm whether that is intended.
     */
    __global__ 
    void scatter_kernel(const int *d_keys,
                        const int *d_segs,
                        const int *d_addrs, 
                        int *d_keys1,
                        int *d_keys2,
                        int *d_ids,                           
                        int siz){
        int tidx = threadIdx.x+ blockIdx.x*blockDim.x;
    
        if(tidx < siz){
            int base_address = d_addrs[tidx];
            int j = 0;
            // The i < siz test must come first: without it the threads of the
            // last segment would read d_segs[siz] and could then store pairs
            // beyond the slots reserved for them.
            for(int i = tidx+1; i < siz && d_segs[i] == d_segs[tidx]; i++){
                if(d_keys[i] != d_keys[tidx] &&
                   d_keys[i] != d_keys[i-1]){
    
                   d_keys1[base_address+j] = d_keys[tidx];
                   d_keys2[base_address+j] = d_keys[i];
                   d_ids[base_address+j] = d_segs[i];                      
                   j++;
                }
            }
        }
    }
    
    // Checks the launch that was just issued: first the launch itself
    // (cudaGetLastError), then its execution (cudaDeviceSynchronize).
    // Prints a message to std::cerr and returns false on any error.
    static bool kernel_succeeded(const char *kernel_name){
        cudaError_t err = cudaGetLastError();
        if(err == cudaSuccess){
            err = cudaDeviceSynchronize();
        }
        if(err != cudaSuccess){
            std::cerr << kernel_name << " failed: " << cudaGetErrorString(err) << std::endl;
            return false;
        }
        return true;
    }
    
    // Demo driver: sorts the keys by segment, counts the pairs per element,
    // scans the counts into output offsets, then writes the pairs.
    // Returns 0 on success and 1 if a kernel launch or execution fails.
    int main(){
    
        int keyArray[] = {1, 2, 3, 1, 2, 2, 1, 1, 1, 1};
        int idsArray[]      = {0, 0, 0, 1, 2, 2, 2, 3, 3, 3};
    
    
        int sz1 = sizeof(keyArray)/sizeof(keyArray[0]);
    
        thrust::host_vector<int> h_keys(keyArray, keyArray+sz1);
        thrust::host_vector<int> h_segs(idsArray, idsArray+sz1);
        thrust::device_vector<int> d_keys = h_keys;
        thrust::device_vector<int> d_segs = h_segs;
        thrust::device_vector<int> d_cpairs(sz1);
        thrust::device_vector<int> d_addrs(sz1);
    
        // One thread per element. The block count is rounded up so inputs larger
        // than one block are still covered; the kernels ignore surplus threads.
        const int threads_per_block = 256;
        const int num_blocks = (sz1 + threads_per_block - 1)/threads_per_block;
    
        //sort each segment to group like keys together
        // (sort by key first, then stable-sort by segment, so keys end up
        //  ordered inside each segment)
        thrust::stable_sort_by_key(d_keys.begin(), d_keys.end(), d_segs.begin());
        thrust::stable_sort_by_key(d_segs.begin(), d_segs.end(), d_keys.begin());
    
        SHOW_VECTOR(d_keys,sz1);
        SHOW_VECTOR(d_segs,sz1);
    
        //count the number of pairs produced by each key
        count_kernel<<<num_blocks,threads_per_block>>>(RAW_CAST(d_keys),RAW_CAST(d_segs),RAW_CAST(d_cpairs),sz1);
        if(!kernel_succeeded("count_kernel")){
            return 1;
        }
    
        SHOW_VECTOR(d_cpairs,sz1);
    
        //determine the total number of pairs
        int num_pairs  = thrust::reduce(d_cpairs.begin(),d_cpairs.end());
    
        std::cout << "num_pairs: " << num_pairs << std::endl;
        //compute the addresses     
        thrust::exclusive_scan(d_cpairs.begin(),d_cpairs.end(),d_addrs.begin());
    
    
        thrust::device_vector<int> keys1(num_pairs);
        thrust::device_vector<int> keys2(num_pairs);
        thrust::device_vector<int> ids(num_pairs);
    
        SHOW_VECTOR(d_addrs,sz1);   
    
        //fill the vector with the keys and ids
        scatter_kernel<<<num_blocks,threads_per_block>>>(RAW_CAST(d_keys),
                                  RAW_CAST(d_segs),
                                  RAW_CAST(d_addrs),
                                  RAW_CAST(keys1),
                                  RAW_CAST(keys2),
                                  RAW_CAST(ids),                              
                                  sz1);
        if(!kernel_succeeded("scatter_kernel")){
            return 1;
        }
    
        SHOW_VECTOR(keys1,num_pairs);
        SHOW_VECTOR(keys2,num_pairs);
        SHOW_VECTOR(ids,num_pairs);
    
    
        return 0;
    }
    

提交回复
热议问题