Why is this naive matrix multiplication faster than base R's?

Submitted by 馋奶兔 on 2019-11-28 21:09:03

A quick look at names.c points you to do_matprod, the C function that is called by %*% and which is found in the file array.c. (Interestingly, it turns out that both crossprod and tcrossprod dispatch to that same function as well.)
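Sharing one function makes sense because, with column-major storage, t(x) %*% y and x %*% t(y) are the same triple loop as x %*% y with the index roles swapped. Neither transpose has to be built. The sketch below shows only that idea. Mat and matprod_op are hypothetical names, the integer op code is an assumption made for illustration, and R's own code is organized differently.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical column-major matrix (not R's API): element (i, j) is data[i + j * nrow].
struct Mat {
    std::size_t nrow, ncol;
    std::vector<double> data;
    double at(std::size_t i, std::size_t j) const { return data[i + j * nrow]; }
};

// op == 0: x %*% y
// op == 1: crossprod(x, y)  == t(x) %*% y
// op == 2: tcrossprod(x, y) == x %*% t(y)
// Transposes are never materialized; only the indexing changes.
Mat matprod_op(const Mat& x, const Mat& y, int op) {
    // Effective dimensions of the (possibly transposed) operands.
    const std::size_t xr = (op == 1) ? x.ncol : x.nrow;
    const std::size_t xc = (op == 1) ? x.nrow : x.ncol;
    const std::size_t yr = (op == 2) ? y.ncol : y.nrow;
    const std::size_t yc = (op == 2) ? y.nrow : y.ncol;
    if (xc != yr) throw std::invalid_argument("non-conformable arguments");

    Mat z{xr, yc, std::vector<double>(xr * yc, 0.0)};
    for (std::size_t j = 0; j < yc; ++j)
        for (std::size_t i = 0; i < xr; ++i) {
            double sum = 0.0;
            for (std::size_t k = 0; k < xc; ++k) {
                const double xv = (op == 1) ? x.at(k, i) : x.at(i, k);  // t(x)(i, k) == x(k, i)
                const double yv = (op == 2) ? y.at(j, k) : y.at(k, j);  // t(y)(k, j) == y(j, k)
                sum += xv * yv;
            }
            z.data[i + j * xr] = sum;
        }
    return z;
}
```

Take x to be the 2 x 3 matrix with columns (1, 2), (3, 4) and (5, 6). Then matprod_op(x, x, 1) produces the 3 x 3 result of crossprod(x), whose top-left entry is 1*1 + 2*2 = 5.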

Scrolling through the function, you can see that it takes care of a number of things your naive implementation does not, including:

  1. Keeps row and column names, where that makes sense.
  2. Allows for dispatch to alternative S4 methods when the two objects being operated on by a call to %*% are of classes for which such methods have been provided. (A separate portion of the function takes care of this.)
  3. Handles both real and complex matrices.
  4. Implements a series of rules for multiplying a matrix by a matrix, a vector by a matrix, a matrix by a vector, and a vector by a vector. (Recall that in R's matrix multiplication, a vector on the LHS is treated as a row vector, whereas on the RHS it is treated as a column vector, when the dimensions allow it; this is the code that makes that so.)
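The vector/matrix part of rule 4 can be sketched as follows for the case where exactly one operand is a plain vector. This is a simplified sketch, and promote_lhs and promote_rhs are hypothetical names. The fallback branches reflect R's documented behaviour that a vector is promoted to either a row or a column matrix so that the two arguments become conformable. The sketch does not cover vector %*% vector.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>

// (nrow, ncol) of an operand.
using Dims = std::pair<std::size_t, std::size_t>;

// vector %*% matrix: prefer a 1 x len row vector; if that does not conform
// but the matrix has a single row, use a len x 1 column vector instead.
Dims promote_lhs(std::size_t len, Dims y) {
    if (len == y.first) return Dims(1, len);  // row vector
    if (y.first == 1) return Dims(len, 1);    // column vector
    throw std::invalid_argument("non-conformable arguments");
}

// matrix %*% vector: prefer a len x 1 column vector; if that does not conform
// but the matrix has a single column, use a 1 x len row vector instead.
Dims promote_rhs(Dims x, std::size_t len) {
    if (len == x.second) return Dims(len, 1);  // column vector
    if (x.second == 1) return Dims(1, len);    // row vector
    throw std::invalid_argument("non-conformable arguments");
}
```

The fallbacks are what make, for example, a length-5 vector times a 1 x 2 matrix come out as a 5 x 2 outer product instead of an error.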

Near the end of the function, it dispatches to either matprod or cmatprod. Interestingly (to me at least), in the case of real matrices, if either matrix might contain NaN or Inf values, matprod dispatches to a function called simple_matprod, which is about as simple and straightforward as your own. Otherwise, it dispatches to one of a couple of BLAS Fortran routines, which are presumably faster when uniformly 'well-behaved' matrix elements can be guaranteed.
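The NaN/Inf guard is useful because a fast kernel is allowed to take shortcuts that are only valid for finite inputs. To illustrate the idea, the sketch below assumes one such shortcut: skipping columns whose multiplier is zero. That shortcut is not taken from R's source or from any particular BLAS, and every function name here is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// A is n_row x n_col, column-major: A(i, j) == a[i + j * n_row].

// True if any element is NaN or +/-Inf.
bool has_non_finite(const std::vector<double>& v) {
    for (double e : v)
        if (!std::isfinite(e)) return true;
    return false;
}

// Plain loop: every product is formed, so NaN and Inf propagate
// as IEEE 754 arithmetic dictates (e.g. NaN * 0 == NaN).
std::vector<double> simple_matvec(const std::vector<double>& a, std::size_t n_row,
                                  std::size_t n_col, const std::vector<double>& x) {
    std::vector<double> y(n_row, 0.0);
    for (std::size_t j = 0; j < n_col; ++j)
        for (std::size_t i = 0; i < n_row; ++i)
            y[i] += a[i + j * n_row] * x[j];
    return y;
}

// Hypothetical "fast" kernel: skips columns whose multiplier is zero.
// Correct for finite inputs, but it silently drops NaN * 0 and Inf * 0.
std::vector<double> skipping_matvec(const std::vector<double>& a, std::size_t n_row,
                                    std::size_t n_col, const std::vector<double>& x) {
    std::vector<double> y(n_row, 0.0);
    for (std::size_t j = 0; j < n_col; ++j) {
        if (x[j] == 0.0) continue;  // the shortcut
        for (std::size_t i = 0; i < n_row; ++i)
            y[i] += a[i + j * n_row] * x[j];
    }
    return y;
}

// Dispatch in the spirit of R's matprod: use the simple loop whenever
// either operand contains a non-finite value, otherwise the fast kernel.
std::vector<double> matvec(const std::vector<double>& a, std::size_t n_row,
                           std::size_t n_col, const std::vector<double>& x) {
    if (has_non_finite(a) || has_non_finite(x))
        return simple_matvec(a, n_row, n_col, x);
    return skipping_matvec(a, n_row, n_col, x);
}
```

Consider A = [[1, NaN], [3, 4]] and x = (1, 0). simple_matvec and matvec both return (NaN, 3), while skipping_matvec returns (1, 3). The scan reads every element once, so for a matrix-vector product the check costs the same order of work as the multiplication itself.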

Josh's answer explains why R's matrix multiplication is not as fast as this naive approach. I was curious to see how much one could gain using RcppArmadillo. The code is simple enough:

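library(Rcpp)  # provides cppFunction(); RcppArmadillo must be installed as well

arma_code <- "
arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
  return m * v;  // Armadillo's overloaded operator* for matrix-vector products
}"
arma_mm <- cppFunction(code = arma_code, depends = "RcppArmadillo")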

Benchmark (m, v and my_mm are the objects from the question):

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

So RcppArmadillo gives us nicer syntax and better performance.

Curiosity got the better of me. Here is a solution that calls BLAS directly:

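blas_code <- "
NumericVector blas_mm(NumericMatrix m, NumericVector v) {
  int nRow = m.rows();
  int nCol = m.cols();
  if (v.size() != nCol) stop(\"non-conformable arguments\");
  NumericVector ans(nRow);
  char trans = 'N';              // use m as is, not its transpose
  double one = 1.0, zero = 0.0;  // alpha and beta in ans := alpha * m %*% v + beta * ans
  int ione = 1;                  // unit stride for v and ans
  // m is stored contiguously in column-major order, so its leading dimension is nRow
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
                  &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")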
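For readers unfamiliar with the BLAS calling convention, here is what that dgemv call computes, written out as a plain loop. ref_dgemv_n is a hypothetical name. It covers only the trans = 'N' case with positive strides and leaves out the argument validation a real BLAS performs. With the arguments used in blas_mm (alpha = one = 1.0, beta = zero = 0.0, lda = nRow, incx = incy = ione = 1), it reduces to ans = m %*% v.

```cpp
#include <cassert>

// Reference semantics of dgemv with trans = 'N' (hypothetical helper):
//   y := alpha * A * x + beta * y
// A is m x n, column-major, with leading dimension lda >= m, so
// A(i, j) == a[i + j * lda]. incx and incy are the strides of x and y.
void ref_dgemv_n(int m, int n, double alpha, const double* a, int lda,
                 const double* x, int incx, double beta, double* y, int incy) {
    // This sketch only supports the no-transpose case with positive strides.
    assert(lda >= m && incx > 0 && incy > 0);
    // Scale y by beta first; beta == 0 means y's old contents are ignored.
    for (int i = 0; i < m; ++i)
        y[i * incy] = (beta == 0.0) ? 0.0 : beta * y[i * incy];
    // Accumulate column by column, which walks A in memory order.
    for (int j = 0; j < n; ++j) {
        const double temp = alpha * x[j * incx];
        for (int i = 0; i < m; ++i)
            y[i * incy] += temp * a[i + j * lda];
    }
}
```

The column-by-column loop order matters for speed. Because R stores matrices column-major, the inner loop reads a contiguous block of memory. Optimized BLAS libraries build further blocking and vectorization on top of that layout.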

Benchmark:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

Armadillo and BLAS (OpenBLAS in my case) perform almost identically. Since R also ends up calling BLAS, roughly 2/3 of the time R spends on m %*% v goes to checks and other overhead rather than to the multiplication itself.
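The 2/3 figure comes from the medians in the second benchmark. The estimate assumes that blas_mm's time approximates the cost of the BLAS call R makes internally:

```latex
1 - \frac{\operatorname{median}(\texttt{blas\_mm})}{\operatorname{median}(\texttt{m \%*\% v})}
  = 1 - \frac{42.69\ \text{ms}}{111.94\ \text{ms}}
  \approx 1 - 0.381 = 0.619
```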
