Group-specific calculations involving both row-specific and whole-group elements

Submitted by 怎甘沉沦 on 2019-12-14 03:48:17

Question


I am having a little trouble matching the logic of this problem to that of dplyr. Usually if you want to reduce a group to a single number per group, you use summarise, while if you want to calculate a separate number for each line, you use mutate. But what if you want to make a calculation on the group for each row?

In the example below, mloc contains a pointer to pnum. The goal is to add a new column nm_child which, for each row, counts how many mloc values within the group point to (i.e. have the same value as) that row's pnum, its index within the group. This would be easy to do with nested loops, or with map() if I knew how to 1) iterate over each group, 2) iterate over each element within the group, and 3) return the map() output as a column of the group.

library(tidyverse)

ser    <- c(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2)
pnum   <- c(1:5, 1:6)
mloc   <- c(0, 2, 2, 0, 3, 1, 1, 0, 0, 3, 4)

tb1 <- tibble(ser, pnum, mloc)
tb2 <- tb1 %>%
  group_by(ser) %>%
  mutate(nm_child = sum(pnum == mloc))

The above gives nm_child = 1 for every row. I can see why it doesn't do what I want, but not why it produces that particular result.
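Why every row gets 1: `==` compares pnum and mloc position by position, so sum(pnum == mloc) counts rows whose own mloc equals their own pnum. That is a single number per group, and mutate() recycles a length-1 result to every row. In group 1 only row 2 matches (2 == 2), and group 2 also happens to have exactly one such row (1 == 1). A base-R sketch on group 1, where rep() only imitates the recycling mutate() does:

```r
# Group 1 of the question's data
pnum <- 1:5
mloc <- c(0, 2, 2, 0, 3)

# `==` pairs the vectors position by position: pnum[i] vs mloc[i]
same_row <- pnum == mloc
same_row
# [1] FALSE  TRUE FALSE FALSE FALSE

# sum() collapses that to ONE number for the whole group ...
sum(same_row)
# [1] 1

# ... and mutate() recycles a length-1 result to every row of the group,
# which rep() imitates here
rep(sum(same_row), length(pnum))
# [1] 1 1 1 1 1

# What the question needs instead: for each pnum value, scan ALL of mloc
vapply(pnum, function(p) sum(mloc == p), integer(1))
# [1] 0 2 1 0 0
```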

I also tried

mutate(nm_child = count(pnum == mloc))

(which returns the error: no applicable method for 'groups' applied to an object of class "logical")

and various other things. I did get one version to work by adding several columns for intermediate values and using a bunch of nested ifelse()s, but it takes more than 20 minutes to run on my nine million rows, whereas regression and most simple dplyr operations take anywhere from a few seconds to too little time to notice.
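On the count() attempt: dplyr's count() is a data-frame verb whose first argument is a data frame, which is why handing it the logical vector pnum == mloc fails. For a plain logical vector, sum() is the counting tool. A small illustration, assuming dplyr (the vector hits is just an example):

```r
library(dplyr)

hits <- c(TRUE, FALSE, TRUE)

# For a plain logical vector, sum() counts the TRUE values
sum(hits)
# [1] 2

# count() wants a data frame plus column names, and returns
# a data frame of counts (one row per distinct value), not a single number
count(tibble(hits), hits)
```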

Desired output:

tb2$nm_child = c(0, 2, 1, 0, 0, 2, 0, 1, 1, 0, 0)
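For checking results on the toy data, the nested-loop idea from the question can be written as a brute-force reference. This is only a sketch: count_children_loop is a hypothetical helper, and because it scans every row for every row it is quadratic in the total row count, so it is not meant for nine million rows.

```r
ser  <- c(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2)
pnum <- c(1:5, 1:6)
mloc <- c(0, 2, 2, 0, 3, 1, 1, 0, 0, 3, 4)

# Hypothetical helper: brute-force reference implementation
count_children_loop <- function(ser, pnum, mloc) {
  out <- integer(length(pnum))
  for (i in seq_along(pnum)) {
    same_group <- ser == ser[i]                  # rows in row i's group (scans all rows)
    out[i] <- sum(mloc[same_group] == pnum[i])   # how many of them point at row i
  }
  out
}

nm_child <- count_children_loop(ser, pnum, mloc)
nm_child
# [1] 0 2 1 0 0 2 0 1 1 0 0
```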

Answer 1:


You can use outer() and rowSums():

tb1 %>% 
  group_by(ser) %>% 
  mutate(nm_child = rowSums(outer(pnum, mloc, `==`)))

# # A tibble: 11 x 4
# # Groups:   ser [2]
#      ser  pnum  mloc nm_child
#    <dbl> <int> <dbl>    <dbl>
#  1     1     1     0        0
#  2     1     2     2        2
#  3     1     3     2        1
#  4     1     4     0        0
#  5     1     5     3        0
#  6     2     1     1        2
#  7     2     2     1        0
#  8     2     3     0        1
#  9     2     4     0        1
# 10     2     5     3        0
# 11     2     6     4        0
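To see what outer() builds for a single group: it returns a logical matrix whose [i, j] entry is pnum[i] == mloc[j], i.e. "row j points at row i", and rowSums() then counts the TRUEs in each row. A sketch on group 1 of the question's data:

```r
# Group 1 of the question's data
pnum <- 1:5
mloc <- c(0, 2, 2, 0, 3)

# m[i, j] is TRUE when mloc[j] points at pnum[i]
m <- outer(pnum, mloc, `==`)
dim(m)
# [1] 5 5

# Summing across each row counts the children of row i
rowSums(m)
# [1] 0 2 1 0 0
```

The matrix has (group size)² cells, so memory and time grow quadratically with group size. That is harmless for groups of 5 or 6 rows, but a per-group join avoids the matrix entirely.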

Benchmark using thelatemail's larger example data (550,000 rows):

tb1 <- tb1[rep(1:11,5e4),]
tb1$ser <- rep(1:1e5, rep(5:6,5e4))

library(data.table)
tb2 <- as.data.table(tb1)

library(microbenchmark)

microbenchmark(
  sapply = {
    tb1 %>% 
      group_by(ser) %>% 
      mutate(
        nm_child = sapply(pnum, function(x) sum(x == mloc))
      )
  },
  join = {
    tb1 %>%
      group_by(ser, mloc) %>%
      summarise(nm_child=n()) %>%
      left_join(tb1, ., by=c("ser"="ser","pnum"="mloc"))
  },
  outer1 = {
    tb1 %>% 
      group_by(ser) %>% 
      mutate(nm_child = rowSums(outer(pnum, mloc, `==`)))
  },
  outer2 = {
    tb1 %>% 
      group_by(ser) %>% 
      mutate(nm_child = colSums(outer(mloc, pnum, `==`)))
  },
  data.table = {
    tb2[tb2[, .N, by=.(ser,mloc)], on=c("ser","pnum"="mloc"), nm_child := N][]
  },
  times = 10)

Benchmark output

# Unit: milliseconds
#        expr       min        lq      mean    median        uq        max neval
#      sapply 8233.5740 8297.7331 8939.9369 8647.5935 8956.3364 10706.3362    10
#        join  889.6682  899.0483  935.7493  908.1441  932.2827  1135.8424    10
#      outer1 4551.0428 4631.1605 5184.9359 4986.7327 5160.0109  7563.4190    10
#      outer2 4495.9134 4552.1169 4763.5954 4723.7783 4893.2190  5198.4556    10
#  data.table  108.7449  115.7866  124.4453  120.6742  125.7591   171.8111    10



Answer 2:


This is an aggregation by ser + mloc, then a left-join back to the original data. There should be no need to loop over every single value:

tb1 %>%
  group_by(ser, mloc) %>%
  summarise(nm_child=n()) %>%
  left_join(tb1, ., by=c("ser"="ser","pnum"="mloc"))

## A tibble: 11 x 4
#     ser  pnum  mloc nm_child
#   <dbl> <dbl> <dbl>    <int>
# 1  1.00  1.00  0          NA
# 2  1.00  2.00  2.00        2
# 3  1.00  3.00  2.00        1
# 4  1.00  4.00  0          NA
# 5  1.00  5.00  3.00       NA
# 6  2.00  1.00  1.00        2
# 7  2.00  2.00  1.00       NA
# 8  2.00  3.00  0           1
# 9  2.00  4.00  0           1
#10  2.00  5.00  3.00       NA
#11  2.00  6.00  4.00       NA
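The NA rows above are pnum values that no mloc in the group points at; the desired output in the question has 0 there. A sketch, assuming dplyr 1.x, where count(ser, mloc, name = ...) is shorthand for the group_by() + summarise(n()) step and coalesce() replaces the NAs:

```r
library(dplyr)

tb1 <- tibble(
  ser  = c(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  pnum = c(1:5, 1:6),
  mloc = c(0, 2, 2, 0, 3, 1, 1, 0, 0, 3, 4)
)

# One row per (ser, mloc) pair with the number of rows pointing there
counts <- tb1 %>%
  count(ser, mloc, name = "nm_child")

# Look each row's pnum up among the mloc values of its own ser
tb2 <- tb1 %>%
  left_join(counts, by = c("ser" = "ser", "pnum" = "mloc")) %>%
  mutate(nm_child = coalesce(nm_child, 0L))   # nobody points at this row -> 0, not NA
```

left_join() keeps the row order of tb1, so nm_child lines up with the desired output.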

This is much more efficient than looping over every row, as timing both approaches on a larger example shows:

# big example
tb1 <- tb1[rep(1:11,5e4),]
tb1$ser <- rep(1:1e5, rep(5:6,5e4))

system.time({
  tb1 %>% 
    group_by(ser) %>% 
    mutate(
      nm_child = sapply(pnum, function(x) sum(x == mloc))
    )
})
#   user  system elapsed 
#   8.83    0.06    8.97     

system.time({
  tb1 %>%
    group_by(ser, mloc) %>%
    summarise(nm_child=n()) %>%
    left_join(tb1, ., by=c("ser"="ser","pnum"="mloc"))
})
#   user  system elapsed 
#   0.67    0.02    0.69 

In base R, the same logic would look something like:

tabu <- aggregate(cbind(nm_child=mloc) ~ ser + mloc, tb1, FUN=length)
merge(tb1, tabu, by.x=c("ser","pnum"), by.y=c("ser","mloc"), all.x=TRUE)
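merge() likewise leaves NA for rows nobody points at, and it sorts the result by the key columns. A self-contained base-R sketch that fills those NAs with 0; note it swaps the formula call above for aggregate()'s data-frame interface (aggregate(x, by, FUN)), so it is a variation rather than the exact code:

```r
df <- data.frame(
  ser  = c(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  pnum = c(1:5, 1:6),
  mloc = c(0, 2, 2, 0, 3, 1, 1, 0, 0, 3, 4)
)

# Count how many rows share each (ser, mloc) pair
tabu <- aggregate(list(nm_child = df$mloc),
                  by = list(ser = df$ser, mloc = df$mloc),
                  FUN = length)

# Left join: each row's pnum is looked up among the mloc values of the same ser
out <- merge(df, tabu, by.x = c("ser", "pnum"), by.y = c("ser", "mloc"),
             all.x = TRUE)

out$nm_child[is.na(out$nm_child)] <- 0L   # nobody points at this row
out <- out[order(out$ser, out$pnum), ]    # sort by ser, pnum (the original order here)
out$nm_child
# [1] 0 2 1 0 0 2 0 1 1 0 0
```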

And to round it off in data.table, which will be an order of magnitude faster again:

library(data.table)
setDT(tb1)  # convert tb1 to a data.table in place
tb1[tb1[, .N, by=.(ser,mloc)], on=c("ser","pnum"="mloc"), nm_child := N]
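A self-contained sketch of the same update join, assuming the data.table package: it builds the question's data as a data.table (dt and counts are names chosen here), adds nm_child by reference, and then sets the unmatched rows, which the join leaves as NA, to 0 to match the desired output:

```r
library(data.table)

dt <- data.table(
  ser  = c(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  pnum = c(1:5, 1:6),
  mloc = c(0, 2, 2, 0, 3, 1, 1, 0, 0, 3, 4)
)

# Number of rows for each (ser, mloc) pair
counts <- dt[, .N, by = .(ser, mloc)]

# Update join: match dt's (ser, pnum) to counts' (ser, mloc), add nm_child by reference
dt[counts, on = c("ser", "pnum" = "mloc"), nm_child := N]

# Rows no mloc points at were not matched and are NA; set them to 0
dt[is.na(nm_child), nm_child := 0L]
dt$nm_child
# [1] 0 2 1 0 0 2 0 1 1 0 0
```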



Answer 3:


Here's a way using sapply():

tb1 %>% 
  group_by(ser) %>% 
  mutate(
    nm_child = sapply(pnum, function(x) sum(x == mloc))
  )

# A tibble: 11 x 4
# Groups:   ser [2]
#      ser  pnum  mloc nm_child
#    <dbl> <int> <dbl>    <int>
#  1  1.00     1  0           0
#  2  1.00     2  2.00        2
#  3  1.00     3  2.00        1
#  4  1.00     4  0           0
#  5  1.00     5  3.00        0
#  6  2.00     1  1.00        2
#  7  2.00     2  1.00        0
#  8  2.00     3  0           1
#  9  2.00     4  0           1
# 10  2.00     5  3.00        0
# 11  2.00     6  4.00        0

Here's another way, using purrr's map_int(), thanks to @RonakShah:

tb1 %>% 
  group_by(ser) %>% 
  mutate(
    nm_child = map_int(pnum, ~sum(. == mloc))
  )

Update: Looking at the benchmarks in the other answers, @thelatemail's join-based answer is certainly the best.



Source: https://stackoverflow.com/questions/56138283/group-specific-calculations-involving-both-row-specific-and-whole-group-elements
