dplyr - filter by group size

Backend · Unresolved · 6 answers · 1908 views

渐次进展 2020-11-28 15:12

What is the best way to filter a data.frame to keep only groups of, say, size 5?

So my data looks as follows:

require(dplyr)
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]

dat <- data.frame(x = x, cat = cat)
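For reference, a minimal sketch of the straightforward "dplyr-way `n()` filter" that the answer below benchmarks against, assuming a data frame `dat` with a grouping column `cat` as in the benchmark setup further down:

```r
library(dplyr)

# Keep only rows whose category occurs exactly 5 times.
# Inside a grouped filter(), n() evaluates to the current group's size.
dat %>%
    group_by(cat) %>%
    filter(n() == 5) %>%
    ungroup()
```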


        
6 Answers
  •  無奈伤痛
    2020-11-28 15:51

    A simple way to speed up the dplyr `n()` filter is to store the group size in a new column. The initial cost of computing the group sizes is amortised when the data is filtered several times later on.

    library(dplyr)
    
    prep_group <- function(dat) {
        dat %>%
            group_by(cat) %>%
            mutate(
                Occurrences = n()
            ) %>%
            ungroup()
    }
    
    # Create a new data frame with the `Occurrences` column:
    # dat_prepped <- dat %>% prep_group
    

    Filtering the Occurrences field is much faster than the workaround solution:

    sol_floo0 <- function(dat){
        dat <- group_by(dat, cat)
        all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
        take_only <- which(group_size(dat) == 5L)
        dat[all_ind %in% take_only, ]
    }
    
    sol_floo0_v2 <- function(dat){
        g <- group_by(dat, cat) %>% group_size()
        ind <- rep(g == 5, g)
        dat[ind, ]
    }
    
    sol_cached <- function(dat) {
        filter(dat, Occurrences == 5L)
    }
    
    n <- 1e5
    x <- rnorm(n)
    # Category size ranging each from 1 to 5
    cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
    
    dat <- data.frame(x = x, cat = cat)
    
    dat_prepped <- prep_group(dat)
    
    microbenchmark::microbenchmark(times=50, sol_floo0(dat), sol_floo0_v2(dat), sol_cached(dat_prepped))
    
    Unit: microseconds
                        expr       min        lq      mean    median        uq        max neval cld
              sol_floo0(dat) 33345.764 35603.446 42430.441 37994.477 41379.411 144103.471    50   c
           sol_floo0_v2(dat) 26180.539 27842.927 29694.203 29089.672 30997.411  37412.899    50  b 
     sol_cached(dat_prepped)   801.402   930.025  1342.348  1098.843  1328.192   5049.895    50 a  
    

    The preparation can be further accelerated by using count() -> left_join():

    prep_join <- function(dat) {
        dat %>%
            left_join(
                dat %>%
                    count(cat, name = "Occurrences"),
                by = "cat"  # join explicitly on the grouping column
            )
    }
    
    microbenchmark::microbenchmark(times=10, prep_group(dat), prep_join(dat))
    
    Unit: milliseconds
                expr      min       lq     mean   median       uq      max neval cld
     prep_group(dat) 45.67805 47.68100 48.98929 49.11258 50.08214 52.44737    10   b
      prep_join(dat) 35.01945 36.20857 37.96460 36.86776 38.71056 45.59041    10  a 
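    Putting the two steps together (a usage sketch, assuming the `dat` data frame from the benchmark setup above): prepare the `Occurrences` column once with `count()` + `left_join()`, then filter it cheaply as often as needed.

    ```r
    library(dplyr)

    # Prepare once: attach a precomputed group-size column.
    dat_prepped <- dat %>%
        left_join(count(dat, cat, name = "Occurrences"), by = "cat")

    # Filter repeatedly on the cached column; no regrouping required.
    groups_of_5 <- filter(dat_prepped, Occurrences == 5L)
    groups_of_3 <- filter(dat_prepped, Occurrences == 3L)
    ```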
    
