subsetting a data.table based on a named list

若如初见. Posted on 2021-02-07 12:48:20

Question


I'm trying to subset a given data.table

DT <- data.table(
  a = c(1:20),
  b = (3:4),
  c = (5:14),
  d = c(1:4)
)

within a function by a parameter which is a named list

param <- list(a = 1:10,
              b = 2:3,
              c = c(5, 7, 10))

I may be a bit stuck here, but I certainly do not want to implement something ugly like this, especially since it's not very dynamic.

DT[(if (!is.null(param$a))
  a %in% param$a
  else
    TRUE)
  &
    (if (!is.null(param$b))
      b %in% param$b
     else
       TRUE)
  &
    (if (!is.null(param$c))
      c %in%  param$c
     else
       TRUE)
  &
    (if (!is.null(param$d))
      d %in% param$d
     else
       TRUE)]
   a b c d
1: 1 3 5 1
2: 3 3 7 3

Any ideas on how to achieve this elegantly in data.table or base R, using the names of the named list to subset the corresponding columns of the data.table by the associated values? Thanks!

EDIT

I performed a microbenchmark with some of the answers:

func_4 <- function(myp, DT) {
  # Filter the argument, not the global `param`
  myp    = Filter(Negate(is.null), myp)

  exs = Map(function(var, val)
    call("%in%", var, val),
    var = sapply(names(myp), as.name),
    val = myp)
  exi = Reduce(function(x, y)
    call("&", x, y), exs)
  ex = call("[", x = as.name("DT"), i = exi)
  # eval(as.call(c(as.list(ex))))
  eval(ex)
}

microbenchmark(
  (DT[do.call(pmin, Map(`%in%`, DT[, names(param), with = FALSE], param)) == 1L]),
  (DT[rowSums(mapply(`%in%`, DT[, names(param), with = FALSE], param)) == length(param)]),
  (DT[do.call(CJ, param), on = names(param), nomatch = NULL]),
  (DT[expand.grid(param), on = names(param), nomatch = NULL]),
  (DT[DT[, all(mapply(`%in%`, .SD, param)), by = 1:nrow(DT), .SDcols = names(param)]$V1]),
  (func_4(myp = param, DT = DT)),
  times = 200)

   min        lq      mean   median        uq       max neval
  446.656  488.5365  565.5597  511.403  533.7785  7167.847   200
  454.120  516.3000  566.8617  538.146  561.8965  1840.982   200
 2433.450 2538.6075 2732.4749 2606.986 2704.5285 10302.085   200
 2478.595 2588.7240 2939.8625 2642.311 2743.9375 10722.578   200
 2648.707 2761.2475 3040.4926 2814.177 2903.8845 10334.822   200
 3243.040 3384.6220 3764.5087 3484.423 3596.9140 14873.898   200

Answer 1:


We can select the columns of DT named in param, apply %in% to each column and its matching list element, and keep only the rows where every result is TRUE.

DT[which(rowSums(mapply(`%in%`, DT[, names(param), with = FALSE],
      param)) == length(param)), ]

#   a b c d
#1: 1 3 5 1
#2: 3 3 7 3
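
To make the intermediate step visible, here is a minimal sketch that splits the one-liner in two, using the DT and param from the question. The names hits and keep are introduced here for illustration only.

```r
library(data.table)

DT <- data.table(a = 1:20, b = 3:4, c = 5:14, d = 1:4)
param <- list(a = 1:10, b = 2:3, c = c(5, 7, 10))

# mapply() pairs each selected column with the param element in the same
# position and returns a logical matrix: one row per row of DT, one column
# per name in param.
hits <- mapply(`%in%`, DT[, names(param), with = FALSE], param)
head(hits, 3)

# A row is kept only when all of its column tests are TRUE.
keep <- rowSums(hits) == length(param)
DT[keep]
```

Because the columns are selected with names(param), the column order always matches the order of the list elements.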



Answer 2:


You can use the CJ (Cross Join) function from data.table to make a filtering table from the list.

lookup <- do.call(CJ, param)
head(lookup)
#    a b  c
# 1: 1 2  5
# 2: 1 2  7
# 3: 1 2 10
# 4: 1 3  5
# 5: 1 3  7
# 6: 1 3 10

DT[
    lookup,
    on = names(lookup),
    nomatch = NULL
]
#    a b c d
# 1: 1 3 5 1
# 2: 3 3 7 3

Note that nomatch = NULL means that any combination in lookup that doesn't exist in DT won't return a row.
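
The question treats a NULL entry as "no filter on that column". A minimal sketch of one way to get the same behaviour with the join, assuming NULL entries are simply dropped before the lookup table is built (the name active is introduced here for illustration):

```r
library(data.table)

DT <- data.table(a = 1:20, b = 3:4, c = 5:14, d = 1:4)
# `d = NULL` stands for "do not filter on d"
param <- list(a = 1:10, b = 2:3, c = c(5L, 7L, 10L), d = NULL)

# Keep only the entries that actually carry filter values
active <- Filter(Negate(is.null), param)

# Cross join of the remaining values, then an inner join against DT
lookup <- do.call(CJ, active)
res <- DT[lookup, on = names(active), nomatch = NULL]
res
```

Filter(Negate(is.null), ...) is the same idiom used in Answer 4 below.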




Answer 3:


Using Map we can do

DT[DT[, all(Map(`%in%`, .SD, param)), by = 1:nrow(DT)]$V1]
#   a b c d
#1: 1 3 5 1
#2: 3 3 7 3

For each row, we check whether every value in that row is present in the corresponding element of param.


Thanks to @Frank, this can be improved by restricting .SD to the columns named in param:

DT[DT[, all(mapply(`%in%`, .SD, param)), by = 1:nrow(DT), .SDcols=names(param)]$V1]



Answer 4:


You could build the expression with call(fun, ...) and as.name:

myp    = Filter(Negate(is.null), param)

exs = Map(function(var, val) call("%in%", var, val), var = sapply(names(myp), as.name), val = myp)
exi = Reduce(function(x,y) call("&", x, y), exs)
ex = call("[", x = as.name("DT"), i = exi)
# DT[i = a %in% 1:10 & b %in% 2:3 & c %in% c(5, 7, 10)]

eval(ex)
#    a b c d
# 1: 1 3 5 1
# 2: 3 3 7 3

By composing the call correctly, you can take advantage of efficient algorithms for "indices" in the data.table (see the package vignettes). You can also turn verbose on to get a note about the inefficiency of specifying param$c as numeric when DT$c is int:

> z <- as.call(c(as.list(ex), verbose=TRUE))
> eval(z)
Optimized subsetting with index 'c__b__a'
on= matches existing index, using index
Coercing double column i.'c' to integer to match type of x.'c'. Please avoid coercion for efficiency.
Starting bmerge ...done in 0.020sec 
   a b c d
1: 1 3 5 1
2: 3 3 7 3

That is, you should use c(5L, 7L, 10L).
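
When param comes from elsewhere and cannot be typed by hand, the same effect can be had by converting whole-number doubles to integer before filtering. This is only a sketch: is_whole_double is a hypothetical helper, not part of data.table, and it assumes the values fit in R's integer range (larger values would become NA).

```r
param <- list(a = 1:10, b = 2:3, c = c(5, 7, 10))

# Hypothetical helper: TRUE for double vectors that contain only whole numbers
is_whole_double <- function(v) is.double(v) && all(v == trunc(v), na.rm = TRUE)

# Convert such vectors to integer and leave everything else untouched
param_int <- lapply(param, function(v) if (is_whole_double(v)) as.integer(v) else v)

vapply(param_int, typeof, character(1))
```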

A join, as in Nathan's answer, also uses indices, but building and joining on the Cartesian table of param will be costly if prod(lengths(param)) is large.
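
The size of that Cartesian table can be checked before building it. A small illustration in base R; the list big is made up for the example:

```r
param <- list(a = 1:10, b = 2:3, c = c(5L, 7L, 10L))

# Number of rows do.call(CJ, param) would produce
n_lookup <- prod(lengths(param))
n_lookup

# Three filters with 1000 values each already imply a billion-row lookup table
big <- list(a = 1:1000, b = 1:1000, c = 1:1000)
n_big <- prod(lengths(big))
n_big
```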


@markus's approach may be slow because it operates row by row, so here is a variant:

DT[do.call(pmin, Map(`%in%`, DT[, names(param), with=FALSE], param)) == 1L]

#    a b c d
# 1: 1 3 5 1
# 2: 3 3 7 3

The trick is that the elementwise version of all is pmin(...) == 1L. Likewise, any corresponds to pmax(...) == 1L. (This is why pany/pall are not included in this conversation on r-devel: http://r.789695.n4.nabble.com/There-is-pmin-and-pmax-each-taking-na-rm-how-about-psum-td4647841.html)
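
A small base R illustration of that equivalence on plain logical vectors (x, y, z and tests are made-up names for the example):

```r
x <- c(TRUE, TRUE, FALSE, FALSE)
y <- c(TRUE, FALSE, TRUE, FALSE)
z <- c(TRUE, TRUE, TRUE, FALSE)
tests <- list(x, y, z)

# Elementwise "all": TRUE only where every vector is TRUE
all_rows <- do.call(pmin, tests) == 1L

# Elementwise "any": TRUE where at least one vector is TRUE
any_rows <- do.call(pmax, tests) == 1L

# Both agree with folding `&` and `|` over the list
identical(all_rows, Reduce(`&`, tests))
identical(any_rows, Reduce(`|`, tests))
```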




Answer 5:


I am adding another answer because the solutions presented by the OP are missing a critical detail: how each one scales with large datasets. I frequently work with datasets of well over 1 million records, so for my own benefit I repeated the OP's microbenchmark on datasets of different sizes for the pmin + %in% + Map solution and the CJ solution, a version of which I had implemented independently. Although the former is markedly faster for small datasets, the latter scales much better.

It looks to me like the point where the relative speed switches is at ~200k records, regardless of the number of fields to subset on, so I packaged both functions into one for future use:

subsel <- function(x, sub, sel = NULL,
                   nomatch = getOption('datatable.nomatch')){
  #' function to subset data.table (x) using a named list (sub). sel
  #' can be used to return only the specified columns. algorithms
  #' copied from https://stackoverflow.com/questions/55728200/subsetting-a-data-table-based-on-a-named-list
  #' and cutoff decided on some ad hoc testing.
  if(is.null(sel)) sel <- names(x)
  if(x[, .N] < 200000L){
    return(
      x[
        do.call(
          pmin,
          Map(`%in%`, x[, .SD, .SDcols = names(sub)], sub)
        ) == 1L,
        .SD,
        .SDcols = sel,
        nomatch = nomatch
      ]
    )
  } else {
    return(
      x[
        do.call(CJ, sub),
        .SD,
        .SDcols = sel,
        on = names(sub),
        nomatch = nomatch
      ]
    )
  }
}

Here is the code used to generate the graph if anyone is curious:

require(data.table)
require(ggplot2)
require(microbenchmark)
require(scales)

subsel <- function(x, sub, nomatch = NULL, sel = list()){
  if(length(sel) == 0) sel <- names(x)
  return(
    x[
      do.call(CJ, sub),
      .SD,
      .SDcols = sel,
      on = names(sub),
      nomatch = nomatch
    ]
  )
}

subsel2 <- function(x, sub, nomatch = NULL, sel = list()){
  if(length(sel) == 0) sel <- names(x)
  return(
    x[
      do.call(
        pmin,
        Map(`%in%`, x[, .SD, .SDcols = names(sub)], sub)
      ) == 1L,
      .SD,
      .SDcols = sel,
      nomatch = nomatch
    ]
  )
}

ll <- list(
  a = letters[1:10],
  b = 1:10,
  c = letters[1:10],
  d = 1:10
)

times <- rbindlist(
  lapply(
    seq(from = 100000, to = 1000000, by = 25000),
    function(y){
      dat <- data.table(
        a = sample(letters, y, replace = T),
        b = sample.int(100, y, replace = T),
        c = sample(letters, y, replace = T),
        d = sample.int(100, y, replace = T)
      )
      return(
        rbindlist(
          lapply(
            2:4,
            function(x){
              return(
                setDT(
                  microbenchmark(
                    subsel(dat, sub = head(ll, x), sel = letters[2:4]),
                    subsel2(dat, sub = head(ll, x), sel = letters[2:4])
                  )
                )[, fields := x]
              )
            }
          )
        )[, size := y]
      )
    }
  )
)

times[
  ,
  expr2 := unlist(
    lapply(
      as.character(expr),
      function(x) unlist(strsplit(x, '(', fixed = T))[1]
    )
  )
]
times[
  ,
  expr2 := factor(
    expr2,
    levels = c('subsel', 'subsel2'),
    labels = c('CJ', 'pmin + Map + %in%')
  )
]

ggplot(times, aes(size, time, group = expr2, color = expr2)) +
  geom_smooth() +
  facet_grid(factor(fields) ~ .) +
  scale_y_continuous(labels = number_format(scale = 1e-6)) +
  labs(
    title = 'Execution Time by Fields to Subset on',
    x = 'Dataset Size',
    y = 'Time (Milliseconds)',
    color = 'Function'
  )


Source: https://stackoverflow.com/questions/55728200/subsetting-a-data-table-based-on-a-named-list
