Rcpp: my distance matrix program is slower than the function in a package

孤独总比滥情好 2020-12-21 17:33

I would like to calculate the pairwise Euclidean distance matrix. Following a suggestion from Dirk Eddelbuettel, I wrote the following Rcpp code:

Nu         
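For reference, the pairwise Euclidean distance matrix between the $n$ rows of an $n \times p$ matrix $X$ is defined entrywise as follows; the symmetry and zero diagonal are what make it enough to compute only the pairs with $i < j$:

$$
D_{ij} = \sqrt{\sum_{k=1}^{p} \left( X_{ik} - X_{jk} \right)^2}, \qquad i, j = 1, \dots, n,
\qquad D_{ij} = D_{ji}, \quad D_{ii} = 0 .
$$

So an $n \times n$ result needs $n(n-1)/2$ distance evaluations, one per unordered pair of distinct rows.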


        
2 Answers
  •  野趣味 (OP)
    2020-12-21 18:37

    You were almost there, but your inner loop body tried to do too much in one line. Template programming is hard enough as it is, and sometimes it is better to spread the instructions out a little to give the compiler a better chance. So I split it into five statements, and it built right away.

    New code:

    #include <Rcpp.h>
    
    using namespace Rcpp;
    
    // Euclidean distance between two numeric vectors of the same length
    double dist1 (NumericVector x, NumericVector y){
      int n = y.length();
      double total = 0;
      for (int i = 0; i < n ; ++i) {
        total += pow(x(i)-y(i),2.0);
      }
      total = sqrt(total);
      return total;
    }
    
    // [[Rcpp::export]]
    NumericMatrix calcPWD (NumericMatrix x){
      int outrows = x.nrow();
      int outcols = x.nrow();
      NumericMatrix out(outrows,outcols);
    
      for (int i = 0 ; i < outrows - 1; i++){
        for (int j = i + 1  ; j < outcols ; j ++) {
          NumericVector v1 = x.row(i);
          NumericVector v2 = x.row(j-1);
          double d = dist1(v1, v2);
          out(j-1,i) = d;
          out(i,j-1)= d;
        }
      }
      return (out) ;
    }
    
    /*** R
    M <- matrix(log(1:9), 3, 3)
    calcPWD(M)
    */
    

    Running it:

    R> sourceCpp("/tmp/mikebrown.cpp")
    
    R> M <- matrix(log(1:9), 3, 3)
    
    R> calcPWD(M)
             [,1]     [,2] [,3]
    [1,] 0.000000 0.740322    0
    [2,] 0.740322 0.000000    0
    [3,] 0.000000 0.000000    0
    R> 
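As a hand check of this output: `matrix()` fills column by column, so row 1 of `M` is $(\log 1, \log 4, \log 7)$ and row 2 is $(\log 2, \log 5, \log 8)$. The single nonzero entry is their distance:

$$
d_{12} = \sqrt{(\log 2)^2 + \left(\log \tfrac{5}{4}\right)^2 + \left(\log \tfrac{8}{7}\right)^2}
\approx \sqrt{0.480453 + 0.049793 + 0.017831} = \sqrt{0.548077} \approx 0.740322 .
$$

The pairs (1, 3) and (2, 3) never appear, so the third row and column are still at their initial 0.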
    

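    You may want to check your indexing logic, though: it looks like you are missing some comparisons. Because the inner loop uses `j-1`, each row is compared once with itself, and the last row is never compared with anything.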
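One way to fix the indexing is to use `j` directly instead of `j-1`; in the Rcpp code above that means `x.row(j)`, `out(j,i)` and `out(i,j)`. The sketch below shows the corrected loop in plain standard C++ so it compiles without R. The `ColMajorMatrix` struct and the names `rowDist` and `calcPWDFixed` are hypothetical stand-ins for `NumericMatrix`, `dist1` and `calcPWD`, assuming R's column-major storage. It also reads row entries in place rather than copying rows into new vectors, which the Rcpp version does on every inner iteration via `NumericVector v1 = x.row(i)`; how much that matters for speed here is untested.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an R numeric matrix: values are stored
// column-major like R, so element (r, c) lives at data[r + c * nrow].
struct ColMajorMatrix {
    std::size_t nrow, ncol;
    std::vector<double> data;
    double operator()(std::size_t r, std::size_t c) const { return data[r + c * nrow]; }
    double& operator()(std::size_t r, std::size_t c) { return data[r + c * nrow]; }
};

// Euclidean distance between rows a and b, read in place (no row copies).
double rowDist(const ColMajorMatrix& x, std::size_t a, std::size_t b) {
    assert(a < x.nrow && b < x.nrow);
    double total = 0.0;
    for (std::size_t k = 0; k < x.ncol; ++k) {
        const double diff = x(a, k) - x(b, k);
        total += diff * diff;
    }
    return std::sqrt(total);
}

// All pairwise row distances: compute each pair i < j once and mirror it.
ColMajorMatrix calcPWDFixed(const ColMajorMatrix& x) {
    const std::size_t n = x.nrow;
    // Zero-initialised, so the diagonal D(i, i) = 0 needs no extra work.
    ColMajorMatrix out{n, n, std::vector<double>(n * n, 0.0)};
    for (std::size_t i = 0; i + 1 < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {  // j, not j - 1
            const double d = rowDist(x, i, j);
            out(i, j) = d;
            out(j, i) = d;
        }
    }
    return out;
}
```

On the 3 × 3 `log(1:9)` example this fills all three off-diagonal pairs, not only (1, 2).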

    Edit: For kicks, here is a more compact version of your distance function:

    // [[Rcpp::export]]
    double dist2(NumericVector x, NumericVector y){
      double d = sqrt( sum( pow(x - y, 2) ) );
      return d;
    }
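For comparison outside R, the same "sum of squared differences, then square root" one-liner can be written in plain standard C++ without Rcpp sugar, using `std::inner_product` with a custom pairwise operation. This is a sketch: `dist2Plain` is a hypothetical name and `std::vector<double>` stands in for `NumericVector`; both inputs are assumed to have the same length.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

// Compact Euclidean distance in the spirit of dist2(): inner_product
// combines element pairs with (a - b)^2 and accumulates them with +.
double dist2Plain(const std::vector<double>& x, const std::vector<double>& y) {
    assert(x.size() == y.size());  // equal lengths assumed
    return std::sqrt(std::inner_product(
        x.begin(), x.end(), y.begin(), 0.0,
        std::plus<double>(),
        [](double a, double b) { return (a - b) * (a - b); }));
}
```

Unlike `pow(x - y, 2)`, this does not build a temporary vector of differences before summing.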
    
