Sum of subvectors of a vector in R

后端 未结 6 733
悲哀的现实
悲哀的现实 2021-01-04 19:01

Given a vector x of length k, I would like to obtain a k-by-k matrix X where X[i,j] is the sum x[i] + ... + x[j] (for i > j this is x[j] + ... + x[i], so X is symmetric).

6条回答
  •  我在风中等你
    2021-01-04 19:19

    Here's another approach, which seems to be significantly faster than the OP's for loop (by a factor of ~30) and faster than the other answers currently posted (by a factor of at least 18):

    x <- 1:5
    # Derive n from x so the two cannot drift apart when x is edited.
    n <- length(x)
    # z[[i]] holds the running sums x[i], x[i] + x[i+1], ..., x[i] + ... + x[n].
    z <- lapply(seq_len(n), function(i) cumsum(x[i:n]))
    # Left-pad each z[[i]] with NAs to length n and bind them as columns; this
    # fills the lower triangle (including the diagonal).
    m <- mapply(function(y, l) c(rep(NA, n - l), y), z, lengths(z))
    # X is symmetric: copy the lower triangle into the upper one.
    m[upper.tri(m)] <- t(m)[upper.tri(m)]
    m
    
    #     [,1] [,2] [,3] [,4] [,5]
    #[1,]    1    3    6   10   15
    #[2,]    3    2    5    9   14
    #[3,]    6    5    3    7   12
    #[4,]   10    9    7    4    9
    #[5,]   15   14   12    9    5
    

    Benchmarks (scroll down for results)

    library(microbenchmark)
    n <- 100
    # Input vector 1, 2, ..., n (seq_len is the empty-safe spelling of 1:n).
    x <- seq_len(n)
    
    # Reference implementation (the OP's double loop): X[i, j] = sum(x[i:j]).
    # Reads `x` and `n` from the enclosing environment.
    f1 <- function() {
      X <- matrix(0, n, n)
      # seq_len() instead of 1:n, so n = 0 yields an empty 0 x 0 matrix rather
      # than indexing out of bounds.
      for (i in seq_len(n)) {
        for (j in seq_len(n)) {
          # i:j counts down when j < i, which makes X symmetric.
          X[i, j] <- sum(x[i:j])
        }
      }
      X
    }
    
    # outer() over every (i, j) index pair, using a vectorised wrapper around
    # the scalar range sum. Reads `x` and `n` from the enclosing environment.
    f2 <- function() {
      range_sum <- function(i, j) sum(x[i:j])
      idx <- 1:n
      outer(idx, idx, FUN = Vectorize(range_sum))
    }
    
    # Brute force over expand.grid(): one range sum per (row, column) pair,
    # poured column-major into an n x n matrix. Reads `x` and `n` from the
    # enclosing environment.
    f3 <- function() {
      pairs <- expand.grid(1:n, 1:n)
      # Var1 (the row index) varies fastest, matching matrix()'s fill order.
      sums <- apply(pairs, 1, function(p) sum(x[p[2]:p[1]]))
      matrix(sums, n, n)
    }
    
    # Fast version: build each column from a single cumsum() and fill the
    # other triangle by symmetry. Reads `x` and `n` from the enclosing
    # environment.
    f4 <- function() {
      # Nothing to sum: return an empty 0 x 0 matrix of x's type.
      if (n == 0) {
        return(matrix(x[0L], 0L, 0L))
      }
      # Column i holds the running sums x[i], x[i] + x[i+1], ..., preceded by
      # i - 1 NAs (the rows above the diagonal) so every column has length n.
      cols <- lapply(seq_len(n), function(i) {
        c(rep(NA, i - 1L), cumsum(x[i:n]))
      })
      # cbind() keeps a matrix even when n = 1 (mapply()'s simplification
      # would return a plain length-1 vector there).
      m <- do.call(cbind, cols)
      # X is symmetric: copy the lower triangle into the upper one.
      m[upper.tri(m)] <- t(m)[upper.tri(m)]
      m
    }
    
    # Dynamic-programming version: start from the diagonal and extend each
    # off-diagonal from the previous one, since
    # sum(x[j:(j + i)]) = sum(x[j:(j + i - 1)]) + x[j + i].
    # Reads `x` and `n` from the enclosing environment.
    f5 <- function() {
      # `nrow = n` stops diag() from treating a length-one x as a matrix size
      # (diag(7) would be a 7 x 7 identity, not a 1 x 1 matrix holding 7).
      X <- diag(x, nrow = n)
      # seq_len(max(n - 1, 0)) skips the loop for n <= 1, where 1:(n - 1)
      # would count down and index out of bounds.
      for (i in seq_len(max(n - 1, 0))) {
        for (j in seq_len(n - i)) {
          X[j + i, j] <- X[j, j + i] <- X[j + i, j + i] + X[j + i - 1, j]
        }
      }
      X
    }
    
    # Time each implementation 25 times; unit = "relative" reports every
    # timing as a multiple of the fastest expression.
    microbenchmark(
      f1(), f2(), f3(), f4(), f5(),
      times = 25L,
      unit = "relative"
    )
    #Unit: relative
    # expr      min       lq     mean   median       uq      max neval
    # f1() 29.90113 29.01193 30.82411 31.15412 32.51668 35.93552    25
    # f2() 29.46394 30.93101 31.79682 31.88397 34.52489 28.74846    25
    # f3() 56.05807 53.82641 53.63785 55.36704 55.62439 45.94875    25
    # f4()  1.00000  1.00000  1.00000  1.00000  1.00000  1.00000    25
    # f5() 16.30136 17.46371 18.86259 17.87850 21.19914 23.68106    25
    
    # Each alternative must reproduce the reference double loop f1().
    all.equal(f1(), f2())
    #[1] TRUE
    all.equal(f1(), f3())
    #[1] TRUE
    all.equal(f1(), f4())
    #[1] TRUE
    all.equal(f1(), f5())
    #[1] TRUE
    

    Updated with the edited function by Neal Fultz.

提交回复
热议问题