dplyr - filter by group size


What is the best way to filter a data.frame to only get groups of say size 5?

So my data looks as follows:

require(dplyr)
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
dat <- data.frame(x = x, cat = cat)

6 Answers
  • 2020-11-28 15:45

    Here's another dplyr approach you can try:

    semi_join(dat, count(dat, cat) %>% filter(n == 5), by = "cat")
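
    To see what the one-liner does, here is the same thing written in two steps (a sketch using the OP's dat): count() builds a one-row-per-category table of group sizes, and semi_join() then keeps only those rows of dat whose cat appears in that table.

    # the same semi_join, split into two steps
    groups_of_5 <- dat %>%
      count(cat) %>%    # one row per cat, group size in column n
      filter(n == 5)    # keep only the categories of size exactly 5
    semi_join(dat, groups_of_5, by = "cat")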
    

    --

    Here's another approach, based on the OP's original one with a small modification:

    n <- 1e5
    x <- rnorm(n)
    # Category size ranging each from 1 to 5
    cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
    
    dat <- data.frame(x = x, cat = cat)
    
    # second data set, so each solution runs on its own copy
    dat2 <- data.frame(x = x, cat = cat)
    
    sol_floo0 <- function(dat){
      dat <- group_by(dat, cat)
      # group index of every row, repeated once per row of its group
      all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
      # which groups have exactly 5 rows
      take_only <- which(group_size(dat) == 5L)
      dat[all_ind %in% take_only, ]
    }
    
    sol_floo0_v2 <- function(dat){
      # group sizes in group order; lines up with the rows because dat is sorted by cat
      g <- group_by(dat, cat) %>% group_size()
      ind <- rep(g == 5, g)
      dat[ind, ]
    }
    
    
    
    microbenchmark::microbenchmark(times = 10,
                                   sol_floo0(dat),
                                   sol_floo0_v2(dat2))
    #Unit: milliseconds
    #               expr      min       lq     mean   median       uq      max neval cld
    #     sol_floo0(dat) 43.72903 44.89957 45.71121 45.10773 46.59019 48.64595    10   b
    # sol_floo0_v2(dat2) 29.83724 30.56719 32.92777 31.97169 34.10451 38.31037    10  a 
    all.equal(sol_floo0(dat), sol_floo0_v2(dat2))
    #[1] TRUE
    
  • 2020-11-28 15:51

    A very simple way of speeding up the dplyr n() filter is to store the group size in a new column. The one-off cost of computing the group sizes is amortised if the data is filtered several more times later on.

    library(dplyr)
    
    prep_group <- function(dat) {
        dat %>%
            group_by(cat) %>%
            mutate(
                Occurrences = n()
            ) %>%
            ungroup()
    }
    
    # Create a new data frame with the `Occurrences` column:
    # dat_prepped <- dat %>% prep_group
    

    Filtering on the Occurrences column is much faster than the workaround solutions above:

    sol_floo0 <- function(dat){
        dat <- group_by(dat, cat)
        all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
        take_only <- which(group_size(dat) == 5L)
        dat[all_ind %in% take_only, ]
    }
    
    sol_floo0_v2 <- function(dat){
        g <- group_by(dat, cat) %>% group_size()
        ind <- rep(g == 5, g)
        dat[ind, ]
    }
    
    sol_cached <- function(dat) {
        filter(dat, Occurrences == 5L)
    }
    
    n <- 1e5
    x <- rnorm(n)
    # Category size ranging each from 1 to 5
    cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
    
    dat <- data.frame(x = x, cat = cat)
    
    dat_prepped <- prep_group(dat)
    
    microbenchmark::microbenchmark(times=50, sol_floo0(dat), sol_floo0_v2(dat), sol_cached(dat_prepped))
    
    Unit: microseconds
                        expr       min        lq      mean    median        uq        max neval cld
              sol_floo0(dat) 33345.764 35603.446 42430.441 37994.477 41379.411 144103.471    50   c
           sol_floo0_v2(dat) 26180.539 27842.927 29694.203 29089.672 30997.411  37412.899    50  b 
     sol_cached(dat_prepped)   801.402   930.025  1342.348  1098.843  1328.192   5049.895    50 a  
    

    The preparation can be further accelerated by using count() -> left_join():

    prep_join <- function(dat) {
        dat %>%
            left_join(
                dat %>%
                    count(cat, name="Occurrences")
            )
    }
    
    microbenchmark::microbenchmark(times=10, prep_group(dat), prep_join(dat))
    
    Unit: milliseconds
                expr      min       lq     mean   median       uq      max neval cld
     prep_group(dat) 45.67805 47.68100 48.98929 49.11258 50.08214 52.44737    10   b
      prep_join(dat) 35.01945 36.20857 37.96460 36.86776 38.71056 45.59041    10  a 
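
    Putting the pieces together, a sketch of the intended workflow with the prep_join() defined above: prepare the Occurrences column once, then filter as often as needed.

    dat_prepped <- prep_join(dat)               # one-off: attach group sizes
    filter(dat_prepped, Occurrences == 5L)      # cheap, repeatable size filter
    filter(dat_prepped, Occurrences >= 3L)      # other size criteria reuse the same column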
    
  • 2020-11-28 15:53

    I know you asked for a dplyr solution, but if you combine it with a bit of purrr you can get there in a single pipeline without defining any new functions. (It is a little slower, though.)

    library(dplyr)
    library(purrr)
    library(tidyr)
    
    dat %>%
      group_by(cat) %>%
      nest() %>%                                  # one row per cat, rows stored in a list-column
      mutate(n = map_int(data, n_distinct)) %>%   # group size (distinct rows per group)
      filter(n == 5) %>%
      select(cat, n)
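
    Note that this returns one row per qualifying category (cat and its size n), not the original rows. If you want the original rows back, as in the other answers, a sketch along the same lines keeps the nested data and unnests it after filtering:

    dat %>%
      group_by(cat) %>%
      nest() %>%
      mutate(n = map_int(data, nrow)) %>%   # group size = number of rows in each nested tibble
      filter(n == 5) %>%
      select(-n) %>%
      unnest(cols = data)                   # back to the original rows of the kept groups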
    
  • 2020-11-28 15:53

    Comparing the answers timewise:

    require(dplyr)
    require(data.table)
    n <- 1e5
    x <- rnorm(n)
    # Category size ranging each from 1 to 5
    cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
    
    dat <- data.frame(x = x, cat = cat)
    
    # second data set for the data.table approaches
    dat2 <- data.frame(x = x, cat = cat)
    
    sol_floo0 <- function(dat){
      dat <- group_by(dat, cat)
      all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
      take_only <- which(group_size(dat) == 5L)
      dat[all_ind %in% take_only, ]
    }
    
    sol_floo0_v2 <- function(dat){
      g <- group_by(dat, cat) %>% group_size()
      ind <- rep(g == 5, g)
      dat[ind, ]
    }
    
    sol_docendo_discimus <- function(dat){ 
      dat <- group_by(dat, cat)
      semi_join(dat, count(dat, cat) %>% filter(n == 5), by = "cat")
    }
    
    sol_akrun <- function(dat2){
      # .I[.N == 5] collects the original row numbers of groups that have exactly 5 rows
      setDT(dat2)[dat2[, .I[.N==5], by = cat]$V1]
    }
    
    sol_sotos <- function(dat2){
      # return a group's rows (.SD) only when the group has exactly 5 of them
      setDT(dat2)[, if(.N == 5) .SD, by = cat]
    }
    
    sol_chirayu_chamoli <- function(dat){
      # run-length encoding works here because dat is already sorted by cat
      rle_ <- rle(dat$cat)
      dat[dat$cat %in% rle_$values[rle_$lengths==5], ]
    }
    
    microbenchmark::microbenchmark(times = 20,
                                   sol_floo0(dat),
                                   sol_floo0_v2(dat),
                                   sol_docendo_discimus(dat), 
                                   sol_akrun(dat2),
                                   sol_sotos(dat2),
                                   sol_chirayu_chamoli(dat))
    

    Results in:

    Unit: milliseconds
                          expr       min        lq      mean    median        uq       max neval  cld
                sol_floo0(dat)  58.00439  65.28063  93.54014  69.82658  82.79997 280.23114    20   cd
             sol_floo0_v2(dat)  42.27791  50.27953  72.51729  58.63931  67.62540 238.97413    20  bc 
     sol_docendo_discimus(dat) 100.54095 113.15476 126.74142 121.69013 132.62533 183.05818    20    d
               sol_akrun(dat2)  26.88369  34.01925  41.04378  37.07957  45.44784  63.95430    20 ab  
               sol_sotos(dat2)  16.10177  19.78403  24.04375  23.06900  28.05470  35.83611    20 a   
      sol_chirayu_chamoli(dat)  20.67951  24.18100  38.01172  27.61618  31.97834 230.51026    20 ab  
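
    One caveat about the data.table timings: setDT() converts dat2 to a data.table by reference, so after the first run dat2 is no longer a plain data.frame and later iterations skip most of the conversion cost. If every run should start from a data.frame, a copying variant of sol_sotos (a sketch, not part of the benchmark above) looks like this:

    sol_sotos_copy <- function(dat2){
      dt <- data.table::as.data.table(dat2)  # copies instead of converting by reference
      dt[, if (.N == 5) .SD, by = cat]
    }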
    
  • 2020-11-28 15:54

    I generalised the function written by docendo discimus so it can be used alongside existing dplyr functions:

    #' @inherit dplyr::filter
    #' @param min minimal group size, use \code{min = NULL} to filter on maximal group size only
    #' @param max maximal group size, use \code{max = NULL} to filter on minimal group size only
    #' @export
    #' @source Stack Overflow answer by docendo discimus, \url{https://stackoverflow.com/a/43110620/4575331}
    filter_group_size <- function(.data, min = NULL, max = min) {
      g <- dplyr::group_size(.data)
      if (is.null(min) && is.null(max)) {
        stop('`min` and `max` cannot both be NULL.')
      }
      if (is.null(max)) {
        max <- base::max(g, na.rm = TRUE)
      }
      ind <- base::rep(g >= min & g <= max, g)
      .data[ind, ]
    }
    

    Let's check it for a minimal group size of 5:

    dat2 %>%
      group_by(cat) %>%
      filter_group_size(5, NULL) %>%
      summarise(n = n()) %>%
      arrange(desc(n))
    
    # # A tibble: 6,634 x 2
    #      cat     n
    #    <int> <int>
    #  1    NA    19
    #  2     1     5
    #  3     2     5
    #  4     6     5
    #  5    15     5
    #  6    17     5
    #  7    21     5
    #  8    27     5
    #  9    33     5
    # 10    37     5
    # # ... with 6,624 more rows
    

    Great. Now check the OP's actual question, a group size of exactly 5:

    dat2 %>%
      group_by(cat) %>%
      filter_group_size(5) %>%
      summarise(n = n()) %>%
      pull(n) %>%
      unique()
    # [1] 5
    

    Hooray.

  • 2020-11-28 16:04

    You can do it more concisely with n():

    library(dplyr)
    dat %>% group_by(cat) %>% filter(n() == 5)
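
    If you are on a recent dplyr (1.1.0 or later), the same idea can also be written without an explicit group_by(), either with per-operation grouping or with add_count(); a sketch, assuming those newer verbs are available:

    # per-operation grouping (dplyr >= 1.1.0), no lingering groups on the result
    dat %>% filter(n() == 5, .by = cat)
    
    # or: attach the group size as a column, filter on it, then drop it
    dat %>% add_count(cat) %>% filter(n == 5) %>% select(-n)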
    