Fast subsetting of a matrix in R

淺唱寂寞╮ submitted on 2019-12-05 10:27:52

The problem with your solution is that subsetting allocates another matrix, which takes time.

You have two solutions:

If the time taken with sum on the whole matrix is okay with you, you could use colSums on the whole matrix and subset the result:

sum(colSums(m0)[1:900])

Or you could use Rcpp to sum over the selected columns directly, without copying the matrix.

#include <Rcpp.h>
using namespace Rcpp;

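// Sum the columns of x listed in colInd (1-based, as in R) without copying x.
// [[Rcpp::export]]
double sumSub(const NumericMatrix& x,
              const IntegerVector& colInd) {

  double sum = 0;

  for (IntegerVector::const_iterator it = colInd.begin(); it != colInd.end(); ++it) {
    int j = *it - 1;  // convert R's 1-based column index to 0-based
    // Walk down the column; R matrices are column-major, so this is contiguous.
    for (int i = 0; i < x.nrow(); i++) {
      sum += x(i, j);
    }
  }

  return sum;
}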

    microbenchmark(m0[, 1:900], sum(m0[, 1:900]), sum(r0[,1:900]), sum(m0),
                   sum(colSums(m0)[1:900]),
                   sumSub(m0, 1:900))
Unit: milliseconds
                    expr      min       lq     mean   median       uq       max neval
             m0[, 1:900] 4.831616 5.447749 5.641096 5.675774 5.861052  6.418266   100
        sum(m0[, 1:900]) 6.103985 6.475921 7.052001 6.723035 6.999226 37.085345   100
        sum(r0[, 1:900]) 6.224850 6.449210 6.728681 6.705366 6.943689  7.565842   100
                 sum(m0) 1.110073 1.145906 1.175224 1.168696 1.197889  1.269589   100
 sum(colSums(m0)[1:900]) 1.113834 1.141411 1.178913 1.168312 1.201827  1.408785   100
       sumSub(m0, 1:900) 1.337188 1.368383 1.404744 1.390846 1.415434  2.459361   100

You could use unrolling optimization to further optimize the Rcpp version.
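As a rough illustration of what unrolling could look like, here is a minimal sketch in plain C++ rather than Rcpp, so that it compiles on its own. It assumes the matrix is stored column by column in a std::vector<double> (the same layout R uses), and the name sumSubUnrolled is hypothetical. Whether this beats the simple loop depends on your compiler and optimization flags, and because the additions happen in a different order, the result can differ from sumSub in the last few bits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sum the selected columns of a column-major matrix without copying it.
// `x` holds nrow * ncol doubles laid out column by column (as R stores
// matrices); `colInd` holds 1-based column indices, as in R.
// Hypothetical helper for illustration, not part of Rcpp.
double sumSubUnrolled(const std::vector<double>& x,
                      std::size_t nrow,
                      const std::vector<int>& colInd) {
  // Four independent accumulators break the dependency chain between additions.
  double s0 = 0, s1 = 0, s2 = 0, s3 = 0;

  for (int idx : colInd) {
    // The requested column must lie inside the matrix.
    assert(idx >= 1 && static_cast<std::size_t>(idx) * nrow <= x.size());
    const double* col = x.data() + static_cast<std::size_t>(idx - 1) * nrow;

    std::size_t i = 0;
    // Main loop: four elements per iteration.
    for (; i + 4 <= nrow; i += 4) {
      s0 += col[i];
      s1 += col[i + 1];
      s2 += col[i + 2];
      s3 += col[i + 3];
    }
    // Remainder loop: the last nrow % 4 elements.
    for (; i < nrow; ++i) {
      s0 += col[i];
    }
  }

  return (s0 + s1) + (s2 + s3);
}
```

To use the same idea from R, the loop body could be moved into the sumSub function above, reading from the NumericMatrix instead of a std::vector. Benchmark it against the plain version before keeping it, since an optimizing compiler may already unroll the simple loop.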

Using the compiler package, I wrote a function that gets the result about 2x as fast as your other methods (8x the time of sum(m0) instead of 16x):

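library(compiler)

# Byte-compile a loop that sums the first 900 columns one at a time,
# so each iteration extracts a single column instead of a 900-column matrix.
compiler_sum <- cmpfun(function(x) {
  tmp <- 0
  for (i in 1:900)
    tmp <- tmp + sum(x[, i])
  tmp
})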

microbenchmark( 
               sum(m0),
               compiler_sum(m0)
               )
Unit: milliseconds
             expr      min       lq     mean   median      uq       max neval
          sum(m0) 1.016532 1.056030 1.107263 1.084503 1.11173  1.634391   100
 compiler_sum(m0) 7.655251 7.854135 8.000521 8.021107 8.29850 16.760058   100