Method to operate on each row of data.table without using apply function

Submitted by 可紊 on 2021-02-18 13:53:21

Question


I wrote a simple function below:

mcs <- function(v) { ifelse(sum((diff(sort(v)) > 6) > 0), NA, sd(v)) }

It is supposed to take a vector, sort it, and then check whether any of the successive differences is greater than 6. It returns NA if there is such a difference and the standard deviation otherwise.
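As a rough illustration of the logic described above, here is a minimal C++ sketch of the same per-vector check. It is not the asker's code. It assumes there are no missing input values, uses NaN to stand in for R's NA, and the name mcs_row is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Sketch of the question's per-vector check (hypothetical name mcs_row).
// NaN stands in for R's NA; missing input values are assumed not to occur.
// The vector is taken by value because it is sorted in place.
double mcs_row(std::vector<double> v) {
    const double na = std::numeric_limits<double>::quiet_NaN();
    if (v.size() < 2) return na;  // R's sd() of fewer than 2 values is NA too

    std::sort(v.begin(), v.end());
    // Any gap between neighbouring sorted values larger than 6 gives NA.
    for (std::size_t j = 1; j < v.size(); ++j) {
        if (v[j] - v[j - 1] > 6.0) return na;
    }

    // Otherwise: sample standard deviation (n - 1 denominator), as in R's sd().
    const double n = static_cast<double>(v.size());
    double mean = 0.0;
    for (double x : v) mean += x;
    mean /= n;
    double ss = 0.0;
    for (double x : v) ss += (x - mean) * (x - mean);
    return std::sqrt(ss / (n - 1.0));
}
```

Note that a gap of exactly 6 does not produce NA, because the test is strictly "greater than 6". Row 3 of the example below (values 4, 10, 3, which sort to 3, 4, 10) depends on that.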

I would like to apply this function across all rows of a data table (choosing only certain columns) and then append the return value for each row as a new column entry to the data table.

For example, given a data table like this:

> dat <- data.table(A=c(1,2,3,4,5), B=c(2,3,4,10,6), C=c(3,4,10,6,8),
+                   D=c(3,3,3,3,3))
> dat
   A  B  C D
1: 1  2  3 3
2: 2  3  4 3
3: 3  4 10 3
4: 4 10  6 3
5: 5  6  8 3

I would like to generate the output below. (I applied the function to columns 2, 3, and 4 of each row.)

> dat
   A  B  C D        sd
1: 1  2  3 3 0.5773503
2: 2  3  4 3 0.5773503
3: 3  4 10 3 3.7859389
4: 4 10  6 3 3.5118846
5: 5  6  8 3 2.5166115

I learned that row-wise operations can be done on data tables using the following method:

> dat[, sd:=apply(.SD, 1, mcs), .SDcols=(c(2,3,4))]

This method works, except that it is too slow. I have to perform this operation on several large data tables, and I wrote a script to do so. However, it only works for the smaller tables: for tables with ~300,000 rows it finishes in a few seconds, but for a table with ~800 million rows my program never finishes. I've tried waiting for two hours, and I think R crashes or hangs, because the console just freezes. I've run the script several times, and it always processes the first few smaller tables correctly (I had the program write each table to a file to check), but when it reaches the large table it never finishes. I am running this on a computing cluster, so I don't think this is a hardware limitation. It is probably poor code.

I am assuming the bottleneck is the looping done in apply, but I don't know how to make it faster. I am pretty new to R, so I am not sure how to optimize my code. I've seen a lot of posts around the Internet about vectorizing, and I think that if I could apply my function to every row simultaneously it would be much faster, but I don't know how to do that. Please help.
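One way to read "vectorizing" here, for this specific case only: when exactly three columns are used, the sorted row is just (min, median, max), so the check becomes two gap comparisons that can be done over whole columns at once instead of one function call per row. This is a swapped-in technique, not something from the question or the answer; in R the same idea would be expressed with pmin() and pmax() on the columns. The C++ sketch below (the name mcs3_columns is hypothetical) shows the two column passes and can be checked against the example table above. It assumes equal-length columns and no missing values.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Column-at-a-time sketch, valid ONLY for exactly three columns (a, b, c).
// Assumes all columns have the same length and contain no missing values.
std::vector<double> mcs3_columns(const std::vector<double>& a,
                                 const std::vector<double>& b,
                                 const std::vector<double>& c) {
    const std::size_t n = a.size();
    std::vector<double> lo(n), mid(n), hi(n), out(n);

    // Pass 1: element-wise min / median / max over whole columns
    // (the R analogue would use pmin()/pmax() on the columns).
    for (std::size_t i = 0; i < n; ++i) {
        lo[i] = std::min({a[i], b[i], c[i]});
        hi[i] = std::max({a[i], b[i], c[i]});
        // Median of three without sorting: max(min(a,b), min(max(a,b), c)).
        mid[i] = std::max(std::min(a[i], b[i]),
                          std::min(std::max(a[i], b[i]), c[i]));
    }

    // Pass 2: gap test, then sample standard deviation (n - 1 = 2).
    for (std::size_t i = 0; i < n; ++i) {
        if (mid[i] - lo[i] > 6.0 || hi[i] - mid[i] > 6.0) {
            out[i] = std::numeric_limits<double>::quiet_NaN();  // stands in for NA
            continue;
        }
        const double m = (a[i] + b[i] + c[i]) / 3.0;
        const double ss = (a[i] - m) * (a[i] - m) + (b[i] - m) * (b[i] - m) +
                          (c[i] - m) * (c[i] - m);
        out[i] = std::sqrt(ss / 2.0);
    }
    return out;
}
```

The median-of-three formula avoids computing the middle value as "sum minus min minus max", which could introduce rounding error right at the boundary of the "greater than 6" test.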

Edit
Sorry, I made a mistake in copying my mcs function. I have updated it.

Edit 2
For those interested, I ended up splitting the table in half and operating on each half separately and that worked for me.


Answer 1:


If you really need speed, it's usually best to move the loop to C++ using Rcpp, which here gives a solution that's over 100x faster.

Data

I made some different example data to test this on, with 1,000 rows instead of 5:

library(data.table)
mcs1 <- mcs  # the question's mcs() function, renamed for the comparison below

set.seed(123)
dat <- data.table(A = rnorm(1e3, sd=4), B = rnorm(1e3, sd=4), C = rnorm(1e3, sd=4),
                  D = rnorm(1e3, sd=4), E = rnorm(1e3, sd=4))

Solution

I used the following C++ code to do the same thing as your function, but now the looping is done in C++ instead of in R through apply, which saves considerable time:

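#include <Rcpp.h>
#include <algorithm>  // std::sort

using namespace Rcpp;

// [[Rcpp::export]]
NumericVector mcs2(DataFrame x) {
    int n = x.nrows();
    int m = x.size();
    // Copy the data frame's columns into a numeric matrix so rows are easy to extract.
    NumericMatrix mat(n, m);
    for ( int j = 0; j < m; ++j ) {
        mat(_, j) = NumericVector(x[j]);
    }
    NumericVector result(n);
    for ( int i = 0; i < n; ++i ) {
        // Copy row i, then sort the copy (mat itself is left unchanged).
        NumericVector tmp = mat(i, _);
        std::sort(tmp.begin(), tmp.end());
        // If any successive difference exceeds 6, the result is NA.
        bool do_sd = true;
        for ( int j = 1; j < m; ++j ) {
            if ( tmp[j] - tmp[j-1] > 6.0 ) {
                result[i] = NA_REAL;
                do_sd = false;
                break;
            }
        }
        // Otherwise, store the standard deviation (Rcpp sugar sd(), n - 1 denominator).
        if ( do_sd ) {
            result[i] = sd(tmp);
        }
    }
    return result;
}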
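A possible refinement, not benchmarked here: mcs2 above first copies the whole data frame into a NumericMatrix and then allocates a new NumericVector for every row. The plain C++ sketch below (no Rcpp types, so it compiles on its own; the names mcs_by_row and max_gap are hypothetical) instead reads each row straight from the column vectors into one buffer that is reused for every row. Hooking it up to a data frame through Rcpp is left out.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Row-wise check over column-stored data, reusing one buffer for every row.
// cols[j][i] is row i of column j; all columns are assumed to have the same
// length and no missing values. max_gap is a hypothetical parameter (6 in the question).
std::vector<double> mcs_by_row(const std::vector<std::vector<double>>& cols,
                               double max_gap = 6.0) {
    const std::size_t m = cols.size();
    const std::size_t n = cols.empty() ? 0 : cols[0].size();
    std::vector<double> result(n);
    std::vector<double> row(m);  // allocated once, overwritten for each row

    for (std::size_t i = 0; i < n; ++i) {
        // Gather row i straight from the columns (no full matrix copy).
        for (std::size_t j = 0; j < m; ++j) row[j] = cols[j][i];
        std::sort(row.begin(), row.end());

        bool big_gap = false;
        for (std::size_t j = 1; j < m; ++j) {
            if (row[j] - row[j - 1] > max_gap) { big_gap = true; break; }
        }
        if (big_gap || m < 2) {
            result[i] = std::numeric_limits<double>::quiet_NaN();  // stands in for NA
            continue;
        }

        // Two-pass sample standard deviation (n - 1 denominator), as in R's sd().
        double mean = 0.0;
        for (double x : row) mean += x;
        mean /= static_cast<double>(m);
        double ss = 0.0;
        for (double x : row) ss += (x - mean) * (x - mean);
        result[i] = std::sqrt(ss / static_cast<double>(m - 1));
    }
    return result;
}
```

Because the buffer is sorted in place, the row values end up in a different order, but that does not change the standard deviation.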

We can make sure it's returning the same values:

all.equal(apply(dat[, 2:4], 1, mcs1), mcs2(dat[,2:4]))

[1] TRUE
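all.equal() compares numbers with a tolerance rather than exactly (by default about 1.5e-8, applied to the mean relative difference), and it also requires NA values to be in the same positions. A rough C++ analogue of that kind of check is sketched below as a hypothetical helper named results_match. It assumes an element-wise relative tolerance instead of all.equal()'s mean relative difference, so it will not agree with all.equal() in every case.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: NaN (standing in for NA) must appear at the same
// positions in both vectors, and the remaining values must agree within an
// element-wise relative tolerance.
bool results_match(const std::vector<double>& x, const std::vector<double>& y,
                   double tol = 1.5e-8) {
    if (x.size() != y.size()) return false;
    for (std::size_t i = 0; i < x.size(); ++i) {
        const bool nx = std::isnan(x[i]);
        const bool ny = std::isnan(y[i]);
        if (nx != ny) return false;  // NA in one result but not the other
        if (nx) continue;            // both NA: treated as equal
        const double scale = std::max(std::fabs(x[i]), std::fabs(y[i]));
        if (std::fabs(x[i] - y[i]) > tol * (scale > 0.0 ? scale : 1.0)) return false;
    }
    return true;
}
```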

Now let's benchmark:

library(rbenchmark)  # provides benchmark()
benchmark(mcs1 = dat[, sd:=apply(.SD, 1, mcs1), .SDcols=(c(2,3,4))],
          mcs2 = dat[, sd:=mcs2(.SD), .SDcols=(c(2,3,4))],
          order = 'relative',
          columns = c('test', 'elapsed', 'relative', 'user.self'))


  test elapsed relative user.self
2 mcs2    0.19    1.000     0.183
1 mcs1   21.34  112.316    20.044

How to compile this code

As an introduction to using C++ code through Rcpp, I'd suggest the Rcpp chapter of Hadley Wickham's Advanced R. If you intend to do anything further with Rcpp, I'd strongly recommend that you also read the official documentation and vignettes, but Wickham's book is probably a little more beginner-friendly as a starting point. For your purposes, you just need to get Rcpp up and running so that you can compile the code above.

For this code to work for you, you'll need the Rcpp package if you don't already have it. You can get the package by running

install.packages("Rcpp")

from R. Note that you'll also need a compiler; if you're on a Debian-based Linux system such as Ubuntu, you can run

sudo apt install r-base-dev

from the terminal. If you are on Mac or Windows, see the platform-specific instructions for setting up a compiler (for example, Rtools on Windows or the Xcode command-line tools on macOS), or the Wickham chapter mentioned above.

Once you have Rcpp installed, save the C++ code above into a file. Let's say for our example the file is named "SOanswer.cpp". Then you can make its mcs2() function available from R by putting the following two lines in your R script:

library(Rcpp)
sourceCpp("SOanswer.cpp") # assuming the file is in your working directory

That's it! Now your R script can call mcs2() and run much faster. If you want to learn more about Rcpp, besides the Wickham chapter above, I'd check out the Rcpp reference manual and vignettes, RStudio's page on Rcpp (which also has tons of links), and the Rcpp Gallery, which has a lot of really useful examples.



Source: https://stackoverflow.com/questions/47737230/method-to-operate-on-each-row-of-data-table-without-using-apply-function
