Why is this naive matrix multiplication faster than base R's?

生来不讨喜 2020-12-14 06:40

In R, matrix multiplication is very optimized, i.e. it is really just a call to BLAS/LAPACK. However, I'm surprised this very naive C++ code for matrix-vector multiplication s…
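The asker's code is cut off above. For reference, here is a minimal sketch of the kind of naive loop in question, written as standalone C++ rather than Rcpp. The function name `naive_mv` and its signature are hypothetical. It assumes the matrix is stored column-major, as R stores matrices, and it loops over columns in the outer loop so that the inner loop reads memory contiguously.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the question's (truncated) naive C++ code.
// m holds an nrow x ncol matrix in column-major order, as R stores it:
// element (i, j) is m[i + j * nrow].
std::vector<double> naive_mv(const std::vector<double>& m,
                             const std::vector<double>& v,
                             std::size_t nrow, std::size_t ncol) {
    assert(m.size() == nrow * ncol);  // matrix storage matches dimensions
    assert(v.size() == ncol);         // one vector entry per column
    std::vector<double> ans(nrow, 0.0);
    // Outer loop over columns, inner loop over rows: the inner loop reads
    // m contiguously and adds v[j] times column j into ans.
    for (std::size_t j = 0; j < ncol; ++j) {
        const double vj = v[j];
        for (std::size_t i = 0; i < nrow; ++i) {
            ans[i] += m[i + j * nrow] * vj;
        }
    }
    return ans;
}
```

Take the 2 x 3 matrix with columns (1, 2), (3, 4), (5, 6) and v = (1, 0, 2). For these inputs `naive_mv` returns (11, 14).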

3 answers
  •  谎友^ (OP)
    2020-12-14 07:04

    Josh's answer explains why R's matrix multiplication is not as fast as this naive approach. I was curious to see how much one could gain using RcppArmadillo. The code is simple enough:

    library(Rcpp)
    # Compile an Armadillo matrix-vector product (requires the RcppArmadillo package)
    arma_code <- 
      "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
           return m * v;
       }"
    arma_mm <- cppFunction(code = arma_code, depends = "RcppArmadillo")
    

    Benchmark:

    > microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
    Unit: milliseconds
              expr      min       lq      mean    median        uq       max neval
       my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
           m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
     arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10
    

    So RcppArmadillo gives us nicer syntax and better performance.

    Curiosity got the better of me. Here is a solution that calls BLAS directly:

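    blas_code <- "
    NumericVector blas_mm(NumericMatrix m, NumericVector v){
      int nRow = m.rows();
      int nCol = m.cols();
      NumericVector ans(nRow);
      char trans = 'N';              // use m as-is, not its transpose
      double one = 1.0, zero = 0.0;  // ans = 1.0 * m * v + 0.0 * ans
      int ione = 1;                  // unit stride for v and ans
      // R matrices are column-major, so the leading dimension of m is nRow.
      // FCONE passes the hidden Fortran string length for 'trans' when
      // USE_FC_LEN_T is in effect and expands to nothing otherwise.
      F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
                      &ione, &zero, ans.begin(), &ione FCONE);
      return ans;
    }"
    blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")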
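The BLAS call above asks `dgemv` to compute y := alpha * A * x + beta * y without transposing A ('N'). The sketch below spells out those semantics in plain C++ for readers unfamiliar with the interface. It follows the loop structure of the reference BLAS for the no-transpose case. `ref_dgemv_n` is a hypothetical name, and the sketch only handles positive strides. It is not what OpenBLAS actually runs, since OpenBLAS uses hand-optimized kernels.

```cpp
#include <cassert>

// Sketch of what dgemv computes when trans = 'N':
//   y := alpha * A * x + beta * y
// A is an m x n column-major matrix with leading dimension lda (>= m),
// so A(i, j) is a[i + j * lda]. incx / incy are the strides of x and y.
// Hypothetical helper; only positive strides are supported here.
void ref_dgemv_n(int m, int n, double alpha, const double* a, int lda,
                 const double* x, int incx, double beta, double* y, int incy) {
    assert(lda >= m && incx > 0 && incy > 0);
    // First scale y by beta; beta == 0 overwrites y, as the reference BLAS does.
    for (int i = 0; i < m; ++i) {
        y[i * incy] = (beta == 0.0) ? 0.0 : beta * y[i * incy];
    }
    // Then accumulate alpha * A * x one column at a time.
    for (int j = 0; j < n; ++j) {
        const double temp = alpha * x[j * incx];
        for (int i = 0; i < m; ++i) {
            y[i * incy] += temp * a[i + j * lda];
        }
    }
}
```

In `blas_mm`, the call corresponds to `ref_dgemv_n(nRow, nCol, 1.0, m.begin(), nRow, v.begin(), 1, 0.0, ans.begin(), 1)`. Here lda equals nRow because R stores each matrix column-major without padding. Because beta = 0, the initial contents of `ans` are ignored.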
    

    Benchmark:

    Unit: milliseconds
              expr      min       lq      mean    median        uq       max neval
       my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
           m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
     arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
     blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10
    

    Armadillo and direct BLAS (OpenBLAS in my case) perform almost identically. Calling BLAS is also what R does in the end, so roughly 2/3 of the time R spends on m %*% v presumably goes to error checking and similar overhead.
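The estimate can be made concrete with the medians from the second benchmark. This treats blas_mm as an approximation of the bare BLAS cost, which is a simplification because it also includes Rcpp's own call overhead. On that basis, the share of R's time not spent in dgemv works out to:

$$
1 - \frac{\operatorname{median}(\texttt{blas\_mm})}{\operatorname{median}(\texttt{m \%*\% v})} = 1 - \frac{42.69\ \text{ms}}{111.94\ \text{ms}} \approx 0.62
$$

That is a little under the rough 2/3 figure.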
