QR decomposition to solve linear systems in CUDA

前端 未结 3 420
离开以前
离开以前 2020-12-29 16:06

I'm writing an image restoration algorithm on the GPU; details are in

Cuda: least square solving , poor in speed

The QR decomposition method to solve the linear sy

3条回答
  •  遥遥无期
    2020-12-29 16:51

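    The following code is a slight expansion of JackOLantern's answer for a general M-by-K right-hand-side matrix b. Basically, you need to copy the upper N-by-N part of the factored A (which holds R) and the top N rows of the intermediate Q^T*b into compact buffers, so that both matrices have the right stride (leading dimension N) for the triangular solve.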

    #include <cassert>
    #include <cstdio>
    #include <cstdlib>
    #include <iostream>
    #include "cuda_runtime.h"
    #include "cublas_v2.h"
    #include "cusolverDn.h"
    #include "cublas_test.h"
    #include "Eigen/Dense"
    #include "gpu_util.h"
    //##############################################################################
    // Prints `name` followed by " =" on one line, then streams `mat` to
    // std::cout on the following line(s).
    //
    // T may be any type that supports `std::cout << mat`; in this file it is
    // called with Eigen matrices.
    template <typename T>
    void PrintEMatrix(const T &mat, const char *name) {
        std::cout << name << " =\n";
        std::cout << mat << std::endl;
    }
    //##############################################################################
    // Copies the top `subM` rows of every column of `d_in` into `d_ou`.
    //
    // Indexing is column-major: element (i, j) is read from d_in[j*M + i]
    // (leading dimension M) and written to d_ou[j*subM + i] (leading
    // dimension subM), for 0 <= i < subM and 0 <= j < N.
    //
    // Launch layout: 2D grid of 2D blocks; the x dimension must cover the
    // subM rows and the y dimension must cover the N columns. Threads that
    // fall outside that range return without touching memory.
    //
    // d_in and d_ou are declared restrict, so they must not alias.
    template <typename T>
    __global__
    void Ker_CopyUpperSubmatrix(const T *__restrict__ d_in,
                                      T *__restrict__ d_ou,
                                const int M, const int N, const int subM) {
        const int i = threadIdx.x + blockIdx.x*blockDim.x;
        const int j = threadIdx.y + blockIdx.y*blockDim.y;
        if (i>=subM || j>=N)
            return;
        // Widen the column index before multiplying so that j*M and j*subM
        // cannot overflow a 32-bit int on large matrices.
        const size_t col = static_cast<size_t>(j);
        d_ou[col*subM+i] = d_in[col*M+i];
    }
    //##############################################################################
    // Solves the least-squares problem A*x = b on the GPU through a QR
    // factorization and compares the result against a known solution.
    //
    // Steps:
    //   1. Build a random M-by-N matrix A and N-by-K reference solution x_ref
    //      on the host and form b = A*x_ref.
    //   2. Factor A = Q*R in place with cusolverDnDgeqrf.
    //   3. Overwrite b with Q^T*b using cusolverDnDormqr.
    //   4. Copy the top N rows of the factored A (holding R) and of Q^T*b into
    //      compact buffers with leading dimension N.
    //   5. Solve R*x = (Q^T*b) with cublasDtrsm and copy x back to the host.
    //   6. Print x_ref, x_sol and the l2 norm of their difference.
    //
    // Any failing CUDA / cuSOLVER / cuBLAS call aborts the process. Returns 0
    // after the comparison has been printed.
    int TestQR() {
        typedef double T; // NOTE: don't change this. blas has different func name
        // Eigen's default storage order is column-major, which is the layout
        // the cuSOLVER / cuBLAS calls below are given (leading dimension = rows).
        typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> MatrixXd;

        // define handles
        cusolverDnHandle_t cusolverH = NULL;
        cublasHandle_t cublasH = NULL;

        const int M = 3;
        const int N = 2;
        const int K = 5;
        // Number of Householder reflectors produced by geqrf. The triangular
        // solve further down uses an N-by-N R, i.e. it assumes M >= N.
        const int numReflectors = (M < N) ? M : N;

        MatrixXd A;
        A = MatrixXd::Random(M,N);
        MatrixXd x_ref, x_sol;
        x_sol.resize(N,K);
        x_ref = MatrixXd::Random(N,K);
        MatrixXd b = A*x_ref;

        PrintEMatrix(A, "A");
        PrintEMatrix(b, "b");
        PrintEMatrix(x_ref, "x_ref");

        // The status is evaluated into a local first so the library call is
        // still executed (and checked) when NDEBUG is defined; wrapping the
        // call inside assert() would compile it away entirely.
    #define CUSOLVER_ERRCHK(x) \
        do { \
            const cusolverStatus_t cusolverStatus_ = (x); \
            if (cusolverStatus_ != CUSOLVER_STATUS_SUCCESS) { \
                std::fprintf(stderr, "cusolver failed (status %d) at %s:%d\n", \
                             static_cast<int>(cusolverStatus_), \
                             __FILE__, __LINE__); \
                std::abort(); \
            } \
        } while (0)
    #define CUBLAS_ERRCHK(x) \
        do { \
            const cublasStatus_t cublasStatus_ = (x); \
            if (cublasStatus_ != CUBLAS_STATUS_SUCCESS) { \
                std::fprintf(stderr, "cublas failed (status %d) at %s:%d\n", \
                             static_cast<int>(cublasStatus_), \
                             __FILE__, __LINE__); \
                std::abort(); \
            } \
        } while (0)

        CUSOLVER_ERRCHK(cusolverDnCreate(&cusolverH));
        CUBLAS_ERRCHK(cublasCreate(&cublasH));

        T *d_A = NULL, *d_b = NULL, *d_work = NULL, *d_work2 = NULL, *d_tau = NULL;
        int *d_devInfo = NULL, devInfo = 0;
        gpuErrchk(cudaMalloc((void**)&d_A, sizeof(T)*M*N));
        gpuErrchk(cudaMalloc((void**)&d_b, sizeof(T)*M*K));
        gpuErrchk(cudaMalloc((void**)&d_tau, sizeof(T)*M));
        gpuErrchk(cudaMalloc((void**)&d_devInfo, sizeof(int)));
        gpuErrchk(cudaMemcpy(d_A, A.data(), sizeof(T)*M*N, cudaMemcpyHostToDevice));
        gpuErrchk(cudaMemcpy(d_b, b.data(), sizeof(T)*M*K, cudaMemcpyHostToDevice));
        int bufSize = 0, bufSize2 = 0;

        // in-place A = QR
        CUSOLVER_ERRCHK(
            cusolverDnDgeqrf_bufferSize(
                cusolverH,
                M,
                N,
                d_A,
                M,
                &bufSize
            )
        );
        gpuErrchk(cudaMalloc((void**)&d_work, sizeof(T)*bufSize));
        CUSOLVER_ERRCHK(
            cusolverDnDgeqrf(
                cusolverH,
                M,
                N,
                d_A,
                M,
                d_tau,
                d_work,
                bufSize,
                d_devInfo
            )
        );
        gpuErrchk(cudaMemcpy(&devInfo, d_devInfo, sizeof(int),
            cudaMemcpyDeviceToHost));
        if (devInfo != 0) {
            std::fprintf(stderr, "QR factorization failed (devInfo = %d)\n",
                         devInfo);
            std::abort();
        }

        // Q^T*b
        // The workspace query and the actual call must be given the same
        // reflector count, so both use numReflectors.
        CUSOLVER_ERRCHK(
            cusolverDnDormqr_bufferSize(
                cusolverH,
                CUBLAS_SIDE_LEFT,
                CUBLAS_OP_T,
                M,
                K,
                numReflectors,
                d_A,
                M,
                d_tau,
                d_b,
                M,
                &bufSize2
            )
        );
        gpuErrchk(cudaMalloc((void**)&d_work2, sizeof(T)*bufSize2));
        CUSOLVER_ERRCHK(
            cusolverDnDormqr(
                cusolverH,
                CUBLAS_SIDE_LEFT,
                CUBLAS_OP_T,
                M,
                K,
                numReflectors,
                d_A,
                M,
                d_tau,
                d_b,
                M,
                d_work2,
                bufSize2,
                d_devInfo
            )
        );
        gpuErrchk(cudaDeviceSynchronize());
        gpuErrchk(cudaMemcpy(&devInfo, d_devInfo, sizeof(int),
            cudaMemcpyDeviceToHost));
        if (devInfo != 0) {
            std::fprintf(stderr, "Q^T b failed (devInfo = %d)\n", devInfo);
            std::abort();
        }

        // need to explicitly copy submatrix for the triangular solve
        T *d_R = NULL, *d_b_ = NULL;
        gpuErrchk(cudaMalloc((void**)&d_R, sizeof(T)*N*N));
        gpuErrchk(cudaMalloc((void**)&d_b_,sizeof(T)*N*K));
        dim3 thd_size(32,32);
        // R: top N rows of the N columns of the factored A.
        dim3 blk_size((N+thd_size.x-1)/thd_size.x,(N+thd_size.y-1)/thd_size.y);
        Ker_CopyUpperSubmatrix<<<blk_size, thd_size>>>(d_A, d_R, M, N, N);
        gpuErrchk(cudaGetLastError());
        // Right-hand side: top N rows of the K columns of Q^T*b.
        blk_size = dim3((N+thd_size.x-1)/thd_size.x,(K+thd_size.y-1)/thd_size.y);
        Ker_CopyUpperSubmatrix<<<blk_size, thd_size>>>(d_b, d_b_, M, K, N);
        gpuErrchk(cudaGetLastError());

        // solve x = R \ (Q^T*B)
        // trsm reads only the upper triangle of d_R, so the Householder
        // vectors stored below the diagonal by geqrf are ignored.
        const T one = 1.0;
        CUBLAS_ERRCHK(
            cublasDtrsm(
                cublasH,
                CUBLAS_SIDE_LEFT,
                CUBLAS_FILL_MODE_UPPER,
                CUBLAS_OP_N,
                CUBLAS_DIAG_NON_UNIT,
                N,
                K,
                &one,
                d_R,
                N,
                d_b_,
                N
            )
        );
        gpuErrchk(cudaDeviceSynchronize());

        gpuErrchk(cudaMemcpy(x_sol.data(), d_b_, sizeof(T)*N*K,
            cudaMemcpyDeviceToHost));

        PrintEMatrix(x_ref, "x_ref");
        PrintEMatrix(x_sol, "x_sol");
        std::cout << "solution l2 error = " << (x_ref-x_sol).norm()
                  << std::endl;

        // Release device buffers and library handles. The function now
        // returns to its caller instead of terminating the process.
        gpuErrchk(cudaFree(d_R));
        gpuErrchk(cudaFree(d_b_));
        gpuErrchk(cudaFree(d_work2));
        gpuErrchk(cudaFree(d_work));
        gpuErrchk(cudaFree(d_devInfo));
        gpuErrchk(cudaFree(d_tau));
        gpuErrchk(cudaFree(d_b));
        gpuErrchk(cudaFree(d_A));
        CUBLAS_ERRCHK(cublasDestroy(cublasH));
        CUSOLVER_ERRCHK(cusolverDnDestroy(cusolverH));

        return 0;
    }
    //##############################################################################
    

提交回复
热议问题