Sum of most recent values across groups

日久生厌 2021-01-31 03:22

For each row of my data, I'd like to compute the sum of the most recent values of each group:

dt = data.table(group = c('a','b','a', ...


        
3 answers
  •  不要未来只要你来
    2021-01-31 04:17

    An even simpler approach from @eddi (in the comments) replaces the roundabout one shown below:

    dt[, incr := diff(c(0, value)), by = group][, ans := cumsum(incr)]
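    Why this works: within each group, diff(c(0, value)) turns the values into step changes (the first value minus 0, then each value minus the group's previous value). Adding up these changes over all rows, in row order, adds each group's newest value and subtracts the value it replaces, so the running total is the sum of the most recent value of every group seen so far. A minimal check on the 3-group example data used below, where the desired column holds the expected result:

```r
library(data.table)

# 3-group example data from this answer; 'desired' is the expected output
dt <- data.table(group   = c('a','b','c','a','a','b','c','a'),
                 value   = c(10, 5, 20, 25, 15, 15, 30, 10),
                 desired = c(10, 15, 35, 50, 40, 50, 60, 55))

# Per group: change relative to that group's previous value (0 before its first row)
dt[, incr := diff(c(0, value)), by = group]
# incr is 10 5 20 15 -10 10 10 -5

# Running total of all changes = sum of each group's latest value so far
dt[, ans := cumsum(incr)]
print(dt)
```

The ans column matches desired on every row.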
    

    I'm not sure how well it extends to more groups, but here it is on example data with 3 groups:

    # I hope I got the desired output correctly
    require(data.table)
    dt = data.table(group = c('a','b','c','a','a','b','c','a'),
                    value = c(10, 5, 20, 25, 15, 15, 30, 10),
                    desired = c(10, 15, 35, 50, 40, 50, 60, 55))
    

    Add a run-length id with rleid(), so that each run of consecutive rows from the same group gets its own id:

    dt[, id := rleid(group)]
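    For reference, rleid() gives consecutive rows of the same group the same id and starts a new id whenever the group changes, so on the example data rows 4 and 5 (both a) share one run:

```r
library(data.table)

g <- c('a','b','c','a','a','b','c','a')
# A new id starts each time the value differs from the previous element
rleid(g)
# [1] 1 2 3 4 4 5 6 7
```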
    

    Extract the last value of each run, i.e. for each (group, id) combination:

    last = dt[, .(value=value[.N]), by=.(group, id)]
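    On the example data this collapses the 8 rows into 7 runs, each holding its group's value at the end of the run. A self-contained sketch that rebuilds last so it can be inspected:

```r
library(data.table)

dt <- data.table(group = c('a','b','c','a','a','b','c','a'),
                 value = c(10, 5, 20, 25, 15, 15, 30, 10))
dt[, id := rleid(group)]

# One row per run: the group's value at the end of that run
last <- dt[, .(value = value[.N]), by = .(group, id)]
print(last)
# group: a b c a b c a
# id:    1 2 3 4 5 6 7
# value: 10 5 20 15 15 30 10   (run 4 covers rows 4-5, so it keeps 15)
```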
    

    last has one row per id. The idea is to compute, for each id, the sum of the most recent values of all the other groups, and then join that back to dt and update by reference.

    last = last[, incr := value - shift(value, type="lag", fill=0L), by=group
              ][, incr := cumsum(incr)-value][]
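    To make the two steps concrete, here they are applied to the 7-row last table from the example data, written out by hand. After the first step, incr is how much each group's latest value changed since that group's previous run; after the second, it is the sum of the latest values of all the other groups at the end of the run, which is why adding the row's own value in the join gives the answer:

```r
library(data.table)

# The 'last' table for the 3-group example data (one row per run)
last <- data.table(group = c('a','b','c','a','b','c','a'),
                   id    = 1:7,
                   value = c(10, 5, 20, 15, 15, 30, 10))

# Step 1: change in each group's latest value vs. its previous run (0 for its first run)
last[, incr := value - shift(value, type = "lag", fill = 0), by = group]
# incr is 10 5 20 5 10 10 -5

# Step 2: running total over all groups, minus this run's own value
last[, incr := cumsum(incr) - value]
# incr is 0 10 15 25 35 30 45
```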
    

    Now join and update by reference:

    dt[last, ans := value + i.incr, on="id"][, id := NULL][]
    #    group value desired ans
    # 1:     a    10      10  10
    # 2:     b     5      15  15
    # 3:     c    20      35  35
    # 4:     a    25      50  50
    # 5:     a    15      40  40
    # 6:     b    15      50  50
    # 7:     c    30      60  60
    # 8:     a    10      55  55
    

    I'm not yet sure where (or whether) this breaks; I'll look at it carefully now. I posted it right away so that more eyes are on it.


    Comparison with David's solution on 10,000 rows and 500 groups:

    require(data.table)
    set.seed(45L)
    groups = apply(matrix(sample(letters, 500L*10L, TRUE), ncol=10L), 1L, paste, collapse="")
    uniqueN(groups) # 500L
    N = 1e4L
    dt = data.table(group=sample(groups, N, TRUE), value = sample(100L, N, TRUE))
    
    arun <- function(dt) {
    
        dt[, id := rleid(group)]
        last = dt[, .(value=value[.N]), by=.(group, id)]
        last = last[, incr := value - shift(value, type="lag", fill=0L), by=group
                  ][, incr := cumsum(incr)-value][]
        dt[last, ans := value + i.incr, on="id"][, id := NULL][]
        dt$ans
    }
    
    david <- function(dt) {
        dt[, indx := .I]
        res <- dcast(dt, indx ~ group)
        for (j in names(res)[-1L]) 
            set(res, j = j, value = res[!is.na(res[[j]])][res, on = "indx", roll = TRUE][[j]])
        rowSums(as.matrix(res)[, -1], na.rm = TRUE)
    
    }
    
    system.time(ans1 <- arun(dt))  ## 0.024s
    system.time(ans2 <- david(dt)) ## 38.97s 
    identical(ans1, as.integer(ans2))
    # [1] TRUE
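    As an extra sanity check, not part of the original timing, @eddi's one-liner from the top of this answer can be compared with the multi-step version on random data of the same shape. In this sketch both are wrapped as hypothetical helpers, multi_step() and one_liner(), that work on a copy so the input table is not modified by reference:

```r
library(data.table)

set.seed(45L)
groups <- apply(matrix(sample(letters, 500L * 10L, TRUE), ncol = 10L), 1L, paste, collapse = "")
N <- 1e4L
dt <- data.table(group = sample(groups, N, TRUE), value = sample(100L, N, TRUE))

# Multi-step approach from this answer (rleid + last value per run + join)
multi_step <- function(d) {
  d <- copy(d)  # avoid modifying the caller's table by reference
  d[, id := rleid(group)]
  last <- d[, .(value = value[.N]), by = .(group, id)]
  last[, incr := value - shift(value, type = "lag", fill = 0L), by = group]
  last[, incr := cumsum(incr) - value]
  d[last, ans := value + i.incr, on = "id"]
  d$ans
}

# @eddi's one-liner; 0L keeps everything integer so identical() compares like with like
one_liner <- function(d) {
  d <- copy(d)
  d[, incr := diff(c(0L, value)), by = group][, ans := cumsum(incr)]
  d$ans
}

identical(multi_step(dt), one_liner(dt))
```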
    
