Passing Host Function as a function pointer in __global__ OR __device__ function in CUDA

前端 未结 2 1769
迷失自我
迷失自我 2020-11-28 15:32

I am currently developing a GPU version of a CPU function (e.g. `Calc(int a, int b, double* c, double* d, CalcInvFunction GetInv)`), in which a host function is passed as a function pointer (`GetInv`). [The rest of the question text is truncated in this copy.]

2条回答
  •  庸人自扰
    2020-11-28 16:21

    Finally, I have been able to pass a host function as a function pointer to a CUDA kernel (a `__global__` function). Thanks to Robert Crovella and njuffa for their answers. I have been able to pass a class member function (a CPU function) as a function pointer to a CUDA kernel. The main problem is that I can only pass a static class member function; I am not able to pass a member function that is not declared static. For example:

    `__host__ __device__ static int CellfunPtr(void* ptr, int a);`

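    The above declaration works because the member function is declared `static`. If I do not declare it as a static member, i.e. `__host__ __device__ int CellfunPtr(void* ptr, int a);`,

    then it doesn't work. (This is expected: a non-static member function takes an implicit `this` argument, so its address has the pointer-to-member type `int (CalcVars::*)(void*, int)`, which cannot be converted to the plain function-pointer type `pFunc_t`.)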

    The complete code has four files.


    1. First file

    /*start of fundef.h file*/

    // Callback signature: receives an opaque object pointer and an int, returns an int.
    typedef int (*pFunc_t)(void* obj, int value);

    /*end of fundef.h file*/


    2. Second file

    /*start of solver.h file*/

        // Let a plain host compiler (e.g. the one building main.cpp) parse the CUDA
        // execution-space qualifiers used below; nvcc defines them itself.
        #if !defined(__CUDACC__)
        #ifndef __host__
        #define __host__
        #endif
        #ifndef __device__
        #define __device__
        #endif
        #endif

        /**
         * Holds problem sizes and two host-side per-equation arrays (cellVel,
         * cellPre), and runs the GPU function-pointer demo in CalcAdv().
         *
         * Ownership: the constructor allocates cellVel/cellPre and the destructor
         * frees them, so instances must not be copied (a copy would free the same
         * buffers twice). Copying is therefore disabled below.
         */
        class CalcVars {
    
           int eqnCount;   // length of cellVel and cellPre
           int numCell;
           int numTri;
           int numTet;

           // Non-copyable: declared but intentionally never defined (pre-C++11
           // idiom), so any accidental copy fails to compile or link instead of
           // double-freeing cellVel/cellPre.
           CalcVars(const CalcVars&);
           CalcVars& operator=(const CalcVars&);
    
        public:
           double* cellVel;   // eqnCount doubles, zero-initialised by the constructor
           double* cellPre;   // eqnCount doubles, zero-initialised by the constructor
    
        /** Constructor: stores the sizes and allocates cellVel/cellPre. */
    
        CalcVars(
            const int eqnCount_,             
            const int numCell_,          
            const int numTri_,             
            const int numTet_                
        );
    
        /** Destructor: releases cellVel/cellPre. */
    
        ~CalcVars(void);
    
        public:
    
          /** Runs CellfunPtr on the GPU through a device function pointer. */
          void 
              CalcAdv();
    
          /**
           * Callback compatible with pFunc_t.
           *
           * Must stay static: a non-static member has an implicit `this`
           * parameter, so its address is a pointer-to-member and cannot be
           * stored in the plain function pointer type pFunc_t.
           */
          __host__ __device__ 
          static int 
              CellfunPtr(
              void*ptr, int a
        );
    
        };
    

    /*end of solver.h file*/


    3. Third file

    /*start of solver.cu file*/

         #include <cstdio>
         #include <cstdlib>
         #include <new>

         #include "solver.h"
         // Device-side copy of CellfunPtr's address. Because this initialiser is
         // compiled for the device, the stored value is the *device* entry point of
         // CellfunPtr, i.e. an address a kernel can legally call through.
         __device__ pFunc_t pF1_d = CalcVars::CellfunPtr;
    
        // Host-side holder for the function pointer handed to `kernel`. It must
        // contain the device address (e.g. copied from pF1_d with
        // cudaMemcpyFromSymbol); taking CalcVars::CellfunPtr's address in host code
        // yields a host address that is not valid on the GPU.
        pFunc_t pF1_h ;
    
    
        __global__ void kernel(int*a, pFunc_t func, void* thisPtr_){
            int tid = threadIdx.x;
            a[tid] = (*func)(thisPtr_, a[tid]); 
        };
    
        /* Constructor */
    
        CalcVars::CalcVars(
            const int eqnCount_,             
            const int numCell_,          
            const int numTri_,             
            const int numTet_   
    
        )
        {
            this->eqnCount = eqnCount_;
            this->numCell = numCell_;
            this->numTri = numTri_;
    
            this->cellVel = (double*) calloc((size_t) eqnCount, sizeof(double)); 
            this->cellPre = (double*) calloc((size_t) eqnCount, sizeof(double)); 
    
        }
    
        /* Destructor */
    
        CalcVars::~CalcVars(void)
        {
           free(this->cellVel);
           free(this->cellPre);
    
        }
    
    
        void 
        CalcVars::CalcAdv(
        ){
    
            /*int b1 = 0;
    
            b1 = CellfunPtr(this, 1);*/
    
           int Num = 50;
           int *a1, *a1_dev;
    
            a1 = (int *)malloc(Num*sizeof(int));
    
            cudaMalloc((void**)&a1_dev, Num*sizeof(int));
    
            for(int i = 0; i >>(a1_dev, pF1_h, this);
    
            cudaDeviceSynchronize();
    
            cudaMemcpy(a1, a1_dev, Num*sizeof(int), cudaMemcpyDeviceToHost);
    
    
        };
    
    
        int 
        CalcVars::CellfunPtr(
            void* ptr, int a
        ){
            //CalcVars* ClsPtr = (CalcVars*)ptr;
            printf("Printing from CPU function\n");
            //int eqn_size = ClsPtr->eqnCount;
            //printf("The number is %d",eqn_size);
            return a-1;
    
        };
    

    /*end of solver.cu file*/


    4. Fourth file

    /*start of main.cpp file*/

        #include "solver.h"
    
    
        int main(){
    
            int n_Eqn, n_cell, n_tri, n_tetra;
            n_Eqn = 100;
            n_cell = 200;
            n_tri = 300;
            n_tetra = 400;
    
           CalcVars* calcvars;
    
           calcvars = new CalcVars(n_Eqn, n_cell, n_tri, n_tetra );
    
           calcvars->CalcAdv();
    
           system("pause");
    
        }
    

    /*end of main.cpp file*/

提交回复
热议问题