CUDA: Tiled matrix-matrix multiplication with shared memory and matrix size which is non-multiple of the block size

爱一瞬间的悲伤 2020-12-08 11:22

I'm trying to familiarize myself with CUDA programming, and having a pretty fun time of it. I'm currently looking at this pdf which deals with matrix multiplication, done

1 Answer
  • 2020-12-08 12:05

    When the matrix dimensions are not multiples of the tile dimensions, some tiles cover the matrices only partially. The elements of such a partially overlapping tile that fall outside the matrix must be set to zero, so that they contribute nothing to the accumulated products. So, extending your code to arbitrarily sized matrices is easy, but it does not amount to a simple index check. Below is my version of the tiled matrix-matrix multiplication kernel for arbitrarily sized matrices:

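    // Assumed value: TILE_DIM must be a compile-time constant equal to
    // blockDim.x and blockDim.y of the launch configuration.
    #define TILE_DIM 16
    
    // Computes C = A * B for row-major matrices of arbitrary size.
    __global__ void MatMul(float* A, float* B, float* C, int ARows, int ACols, int BRows,
        int BCols, int CRows, int CCols)
    {
        float CValue = 0.0f;
    
        // Global row and column of the C element handled by this thread.
        int Row = blockIdx.y*TILE_DIM + threadIdx.y;
        int Col = blockIdx.x*TILE_DIM + threadIdx.x;
    
        __shared__ float As[TILE_DIM][TILE_DIM];
        __shared__ float Bs[TILE_DIM][TILE_DIM];
    
        // Walk over all tiles along the shared dimension (ceiling division).
        for (int k = 0; k < (TILE_DIM + ACols - 1)/TILE_DIM; k++) {
    
             // Load one element of the A tile, or zero if it lies outside A.
             if (k*TILE_DIM + threadIdx.x < ACols && Row < ARows)
                 As[threadIdx.y][threadIdx.x] = A[Row*ACols + k*TILE_DIM + threadIdx.x];
             else
                 As[threadIdx.y][threadIdx.x] = 0.0f;
    
             // Load one element of the B tile, or zero if it lies outside B.
             if (k*TILE_DIM + threadIdx.y < BRows && Col < BCols)
                 Bs[threadIdx.y][threadIdx.x] = B[(k*TILE_DIM + threadIdx.y)*BCols + Col];
             else
                 Bs[threadIdx.y][threadIdx.x] = 0.0f;
    
             // Wait until both tiles are fully loaded.
             __syncthreads();
    
             for (int n = 0; n < TILE_DIM; ++n)
                 CValue += As[threadIdx.y][n] * Bs[n][threadIdx.x];
    
             // Wait until every thread is done with the tiles before overwriting them.
             __syncthreads();
        }
    
        // Only threads that map to a valid element of C write a result.
        if (Row < CRows && Col < CCols)
            C[Row*CCols + Col] = CValue;
    }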
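
To see why zero padding is enough, the same scheme can be emulated on the CPU and compared against a naive triple loop. The following is only a sketch in plain host-side C++, not part of the original answer: `kTile`, `tilesToCover`, `matMulNaive` and `matMulTiled` are hypothetical names introduced here, and the sequential loops over `ty`/`tx` merely stand in for the threads of one block.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile width for this sketch; it plays the role of TILE_DIM.
constexpr int kTile = 4;

// Number of tiles needed to cover n elements (ceiling division), the same
// expression as the kernel's loop bound (TILE_DIM + ACols - 1) / TILE_DIM.
int tilesToCover(int n) { return (n + kTile - 1) / kTile; }

// Reference result: plain triple loop over row-major M x K and K x N inputs.
std::vector<float> matMulNaive(const std::vector<float>& A, const std::vector<float>& B,
                               int M, int K, int N) {
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < K; ++k)
                C[i * N + j] += A[i * K + k] * B[k * N + j];
    return C;
}

// CPU emulation of the tiled scheme: (by, bx) stands in for blockIdx and
// (ty, tx) for threadIdx. Out-of-range tile entries are loaded as zero.
std::vector<float> matMulTiled(const std::vector<float>& A, const std::vector<float>& B,
                               int M, int K, int N) {
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    float As[kTile][kTile];
    float Bs[kTile][kTile];
    for (int by = 0; by < tilesToCover(M); ++by) {
        for (int bx = 0; bx < tilesToCover(N); ++bx) {
            float acc[kTile][kTile] = {};  // one accumulator per emulated thread
            for (int t = 0; t < tilesToCover(K); ++t) {
                // Load phase: zero-fill whatever falls outside A or B.
                for (int ty = 0; ty < kTile; ++ty) {
                    for (int tx = 0; tx < kTile; ++tx) {
                        int row = by * kTile + ty, col = bx * kTile + tx;
                        int aCol = t * kTile + tx, bRow = t * kTile + ty;
                        As[ty][tx] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
                        Bs[ty][tx] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
                    }
                }
                // Compute phase: the padded zeros add nothing to the sums.
                for (int ty = 0; ty < kTile; ++ty)
                    for (int tx = 0; tx < kTile; ++tx)
                        for (int n = 0; n < kTile; ++n)
                            acc[ty][tx] += As[ty][n] * Bs[n][tx];
            }
            // Store phase: only positions inside C are written.
            for (int ty = 0; ty < kTile; ++ty) {
                for (int tx = 0; tx < kTile; ++tx) {
                    int row = by * kTile + ty, col = bx * kTile + tx;
                    if (row < M && col < N) C[row * N + col] = acc[ty][tx];
                }
            }
        }
    }
    return C;
}
```

Calling both functions with sizes that are not multiples of `kTile` (for example a 5 x 7 matrix times a 7 x 3 matrix) should give matching results, because every padded entry multiplies into the sum as zero. The same ceiling division is what a host program would typically use to size the grid for the kernel, so that partially covered tiles at the right and bottom edges still get a block.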
    