Matrix Multiplication Using NumericMatrix and NumericVector in Rcpp

Submitted by ぃ、小莉子 on 2019-12-17 20:25:09

Question


Is there a way to compute a matrix product using the NumericMatrix and NumericVector classes? I am looking for a simple way to avoid the following loop. I just want to calculate X %*% beta.

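// Assume X and beta are initialized: X is a NumericMatrix of dimension
// (nsites, p) and beta is a NumericVector with p elements.
NumericVector result(nsites);   // holds X %*% beta, one entry per row of X
for (int j = 0; j < nsites; j++)
{
    double temp = 0;            // running dot product of row j with beta

    for (int l = 0; l < p; l++) temp += X(j, l) * beta[l];

    result[j] = temp;
}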

Thank you very much in advance!


Answer 1:


Building off of Dirk's comment, here are a few cases that demonstrate the Armadillo library's matrix multiplication via the overloaded * operator:

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export(".mm")]]
arma::mat mm_mult(const arma::mat& lhs,
                  const arma::mat& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vm")]]
arma::mat vm_mult(const arma::vec& lhs,
                  const arma::mat& rhs)
{
  return lhs.t() * rhs;
}

// [[Rcpp::export(".mv")]]
arma::mat mv_mult(const arma::mat& lhs,
                  const arma::vec& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vv")]]
arma::mat vv_mult(const arma::vec& lhs,
                  const arma::vec& rhs)
{
  return lhs.t() * rhs;
}
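
The four wrappers differ only in how a bare vector is oriented: an arma::vec is a column vector, so lhs.t() turns it into a row before multiplying. As a rough illustration of the resulting shapes, here is a plain C++ sketch with no Armadillo dependency. The names Shape, column, transposed, and product_shape are hypothetical, invented for this example, and are not part of any library.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical helper type for illustration: a (rows, cols) pair.
using Shape = std::pair<std::size_t, std::size_t>;

// An n-element vector treated as a column, i.e. n x 1.
inline Shape column(std::size_t n) { return {n, 1}; }

// Transposing swaps rows and columns (a column becomes a row).
inline Shape transposed(Shape s) { return {s.second, s.first}; }

// Shape of the product A * B; the inner dimensions must agree.
inline Shape product_shape(Shape a, Shape b) {
    assert(a.second == b.first && "inner dimensions must agree");
    return {a.first, b.second};
}
```

With 3 x 3 matrices and 3-element vectors, product_shape yields 3 x 3 for matrix times matrix, 1 x 3 for transposed vector times matrix, 3 x 1 for matrix times vector, and 1 x 1 for transposed vector times vector.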

You could then define an R function that dispatches to the appropriate C++ function:

`%a*%` <- function(x,y) {

  if (is.matrix(x) && is.matrix(y)) {
    return(.mm(x,y))
  } else if (!is.matrix(x) && is.matrix(y)) {
    return(.vm(x,y))
  } else if (is.matrix(x) && !is.matrix(y)) {
    return(.mv(x,y))
  } else {
    return(.vv(x,y))
  }

}
##
mx <- matrix(1,nrow=3,ncol=3)
vx <- rep(1,3)
my <- matrix(.5,nrow=3,ncol=3)
vy <- rep(.5,3)

Comparing the results against R's %*% operator:

R>  mx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5

R>  mx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5
##
R>  vx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5

R>  vx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
##
R>  mx %a*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5

R>  mx %*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5
##
R>  vx %a*% vy
     [,1]
[1,]  1.5

R>  vx %*% vy
     [,1]
[1,]  1.5
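
For readers who would rather keep a hand-written loop like the one in the question, the matrix-vector case can also be sketched in plain C++. This is only an illustration under stated assumptions: the matrix is held in a flat std::vector in column-major order (the layout R uses for matrices, so element (j, l) sits at index j + l * nrow), and matvec_colmajor is a hypothetical name, not an Rcpp or Armadillo function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: computes y = X * beta, where X is an
// nrow x ncol matrix stored column-major in a flat vector,
// i.e. element (j, l) lives at x[j + l * nrow].
std::vector<double> matvec_colmajor(const std::vector<double>& x,
                                    std::size_t nrow, std::size_t ncol,
                                    const std::vector<double>& beta) {
    assert(x.size() == nrow * ncol);   // matrix storage matches its shape
    assert(beta.size() == ncol);       // operands must be conformable

    std::vector<double> y(nrow, 0.0);
    // Columns in the outer loop, so x is read contiguously in memory.
    for (std::size_t l = 0; l < ncol; ++l) {
        const double b = beta[l];
        for (std::size_t j = 0; j < nrow; ++j) {
            y[j] += x[j + l * nrow] * b;
        }
    }
    return y;
}
```

Putting the column index in the outer loop is a design choice rather than a requirement: it gives the same result as the row-by-row loop in the question, but walks the column-major storage in order.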


Source: https://stackoverflow.com/questions/28465766/matrix-multiplication-using-numericmatrix-and-numericvector-in-rcpp
