Fastest way to compute row-wise dot products between two skinny tall matrices in R

星月不相逢 2020-12-21 20:47

Consider A and B, two tall skinny matrices of dimension 10^8 x 5, i.e.:

n <- 1e8   # rows
p <- 5     # columns (avoids masking base::c)
A <- matrix(runif(n * p, 0, 1), n, p)   # about 4 GB of doubles each
B <- matrix(runif(n * p, 0, 1), n, p)
1 answer
  • 2020-12-21 21:17

    This might disappoint you, but at the R level this is already the best you can get without writing some C code yourself. The problem is that by doing rowSums(A * B), you are effectively doing

    C <- A * B
    rowSums(C)
    

    The first line performs a full scan of three large tall-thin matrices (reading A and B, writing C), while the second line performs a full scan of one (reading C). Altogether, this is equivalent to scanning a tall-thin matrix 4 times, which is memory intensive.
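    To put rough numbers on this, here is a back-of-the-envelope sketch for the question's 10^8 x 5 matrices. The assumptions are ours, not measured: 8-byte doubles, every pass streams a whole matrix through RAM, and cache reuse or write-allocate traffic is ignored.

```r
# Rough memory-traffic estimate for the question's 10^8 x 5 matrices.
# Assumptions (ours): 8-byte doubles, every pass streams a whole matrix
# through RAM, cache effects and write-allocate traffic are ignored.
n <- 1e8
p <- 5
matrix_gb <- n * p * 8 / 1e9   # size of one n x p double matrix, in GB

traffic_gb <- function(passes) passes * matrix_gb

traffic_gb(4)  # rowSums(A * B): read A, read B, write C, read C -> 16
traffic_gb(2)  # fused row-wise loop: read A, read B -> 8
```

    Halving the number of passes roughly halves the memory traffic, and the fused version also never needs the extra 4 GB allocation for C.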

    In fact, the optimal algorithm for this operation only needs to scan an n * p tall-thin matrix twice (once through A, once through B), by doing a row-wise cross product accumulated column by column:

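    rowsum <- numeric(n)                    # n = nrow(A)
    for (j in seq_len(p)) {                 # p = ncol(A)
      rowsum <- rowsum + A[, j] * B[, j]    # add column j's contribution
    }

    In this way, we also avoid generating the full matrix C. Note that this R loop only illustrates the access pattern: each iteration still allocates length-n temporaries for A[, j], B[, j] and their product. To actually read A and B just once each without extra allocation, we want to program the loop in C.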
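    A self-contained sketch of the column-wise accumulation wrapped as a function and checked against rowSums(A * B) on small stand-in matrices. The name rowdot_cols is hypothetical (ours, not from any package), and in pure R it demonstrates correctness and the access pattern, not the C-level speed-up.

```r
# Hypothetical helper (name is ours): row-wise dot products of two n x p
# matrices, accumulated column by column instead of forming A * B.
rowdot_cols <- function(A, B) {
  stopifnot(identical(dim(A), dim(B)))
  acc <- numeric(nrow(A))
  for (j in seq_len(ncol(A))) {
    # Only column j of A and B is touched here; R still allocates
    # length-n temporaries, which a C implementation would avoid.
    acc <- acc + A[, j] * B[, j]
  }
  acc
}

# Small stand-ins for the question's matrices.
set.seed(1)
A_small <- matrix(runif(1e4 * 5), ncol = 5)
B_small <- matrix(runif(1e4 * 5), ncol = 5)

res <- rowdot_cols(A_small, B_small)
all.equal(res, rowSums(A_small * B_small))   # TRUE up to rounding

# Tiny hand-checkable case: rows (1,4).(6,3), (2,5).(5,2), (3,6).(4,1)
rowdot_cols(matrix(1:6, 3), matrix(6:1, 3))  # 18 20 18
```

    all.equal() is used rather than identical() because rowSums() may accumulate in extended precision, so the last bits can differ.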


    An analogy to your situation is the speed difference between sum(x * y) and crossprod(x, y), assuming x and y are large vectors of the same length.

    x <- runif(1e+7)
    y <- runif(1e+7)
    
    system.time(sum(x * y))
    #   user  system elapsed 
    #  0.124   0.032   0.158 
    
    system.time(crossprod(x, y))
    #   user  system elapsed 
    #  0.036   0.000   0.056 
    

    In the first case, we scan a long vector 4 times (read x, read y, write x * y, then read it back in sum()), while in the second case we only scan it twice (read x and y).
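    A quick sketch confirming that the two expressions give the same value (timings depend on the machine and BLAS, so none are claimed here). One practical detail: crossprod() returns a 1 x 1 matrix, so drop() is used to turn it into a scalar before comparing.

```r
# Both forms compute the same inner product of two long vectors.
set.seed(42)
x <- runif(1e6)
y <- runif(1e6)

s1 <- sum(x * y)             # materializes the temporary vector x * y first
s2 <- drop(crossprod(x, y))  # computed directly (typically via BLAS); 1 x 1 matrix dropped to a scalar

all.equal(s1, s2)            # TRUE up to floating-point rounding
```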


    Relevance in statistical computing

    rowSums(A * B) is in fact an efficient evaluation of diag(tcrossprod(A, B)), which is commonly needed in regression computations involving point-wise prediction variance. For example, in ordinary least squares regression with the thin Q matrix from a QR factorization of the model matrix, the point-wise variances of the fitted values are proportional to diag(tcrossprod(Q)) (the leverages, scaled by the error variance sigma^2), which is more efficiently computed by rowSums(Q ^ 2). Even so, this is still not the fastest evaluation, for the reasons explained above.
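    A small sketch of both identities on simulated data. The model matrix X (an intercept plus three random covariates) and the response y are hypothetical, chosen only to make the check runnable; hatvalues() from the stats package serves as an independent reference for the leverages.

```r
set.seed(7)
n <- 200

# 1) diag(tcrossprod(A, B)) vs rowSums(A * B): same numbers, but the first
#    builds an n x n matrix just to read its diagonal.
A_small <- matrix(rnorm(n * 4), n, 4)
B_small <- matrix(rnorm(n * 4), n, 4)
d_full <- diag(tcrossprod(A_small, B_small))
d_fast <- rowSums(A_small * B_small)
all.equal(d_full, d_fast)

# 2) Leverages h_i = diag(Q Q^T) from the thin Q of a hypothetical model
#    matrix X; Var(fitted_i) = sigma^2 * h_i.
X <- cbind(1, matrix(rnorm(n * 3), n, 3))
Q <- qr.Q(qr(X))          # n x 4 thin Q factor
h_qr <- rowSums(Q^2)

y <- rnorm(n)             # any response works: leverages depend only on X
fit <- lm(y ~ X - 1)      # X already contains the intercept column
all.equal(h_qr, unname(hatvalues(fit)))
sum(h_qr)                 # equals ncol(X) = 4 up to rounding
```

    The final line is a useful sanity check: the leverages always sum to the rank of the model matrix, since trace(Q Q^T) = trace(Q^T Q).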
