Faster weighted sampling without replacement

臣服心动 2020-12-01 01:45

This question led to a new R package: wrswoR

R's default sampling without replacement using sample.int se…

  遥遥无期 2020-12-01 02:15

    I decided to dig into some of the comments and found the Efraimidis & Spirakis paper fascinating (thanks to @Hemmo for finding the reference). The general idea in the paper is this: for each item, generate a uniform random number and raise it to the power of one over that item's weight to get a key. Then simply take the items with the highest keys as your sample. This works out brilliantly!
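Why the keys produce the right probabilities can be sketched with a standard exponential-race argument (written out here, not quoted from the paper). With u_i uniform on (0, 1) and weight w_i > 0, taking logs turns each key into an exponential variable, and the largest key corresponds to the smallest exponential:

```latex
\begin{aligned}
k_i &= u_i^{1/w_i}, \qquad u_i \sim \mathrm{Uniform}(0,1),\; w_i > 0 \\
-\log k_i &= \frac{-\log u_i}{w_i} \sim \mathrm{Exponential}(\text{rate } w_i) \\
\Pr\Big(k_i = \max_j k_j\Big) &= \Pr\Big(-\log k_i = \min_j(-\log k_j)\Big) = \frac{w_i}{\sum_j w_j}
\end{aligned}
```

So the first item taken is chosen with probability proportional to its weight, and by the memoryless property of the exponential distribution the same proportional rule holds for each later pick among the remaining items, which is exactly weighted sampling without replacement.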

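    weighted_Random_Sample <- function(
        .data,     # items to sample from
        .weights,  # strictly positive weights, same length as .data
        .n         # number of items to draw
        ){
    
        # One key per item: u^(1/w) with u ~ Uniform(0, 1)
        key <- runif(length(.data)) ^ (1 / .weights)
        # Keep the .n items with the largest keys; head() also handles .n = 0,
        # where [1:.n] would evaluate to [c(1, 0)] and wrongly return one item
        head(.data[order(key, decreasing = TRUE)], .n)
    }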
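A quick empirical check of the function, assuming made-up weights c(a = 1, b = 2, c = 7): over many single draws, the share of times each item comes out should approach w_i / sum(w). The function is repeated (with head() in place of [1:.n]) so the snippet runs on its own.

```r
weighted_Random_Sample <- function(.data, .weights, .n) {
  # Key u^(1/w); larger weights push keys toward 1
  key <- runif(length(.data)) ^ (1 / .weights)
  head(.data[order(key, decreasing = TRUE)], .n)
}

set.seed(42)
w <- c(a = 1, b = 2, c = 7)  # hypothetical example weights
# Draw a single item 20000 times and tabulate how often each one wins
draws <- replicate(20000, weighted_Random_Sample(names(w), w, 1))
observed <- as.numeric(table(factor(draws, levels = names(w)))) / length(draws)
expected <- as.numeric(w / sum(w))
print(round(rbind(observed, expected), 3))
```

With 20000 draws the observed shares should sit within about a percentage point of 0.1, 0.2 and 0.7.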
    

    If you set '.n' to the length of '.data' (which should always equal the length of '.weights'), the function returns a weighted random permutation of '.data' rather than a subset, so the method works well for both sampling and permutation.
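A sketch of the permutation case, with hypothetical weights: passing .n = length(.data) returns every item exactly once, and heavier items tend to land near the front. The snippet averages the positions of the heaviest item ("a") and a light one ("f") over repeated permutations.

```r
weighted_Random_Sample <- function(.data, .weights, .n) {
  key <- runif(length(.data)) ^ (1 / .weights)
  head(.data[order(key, decreasing = TRUE)], .n)
}

set.seed(1)
items <- letters[1:6]
w <- c(10, 5, 3, 1, 1, 1)  # hypothetical weights: "a" is the heaviest

# Full-length "sample" = weighted random permutation
perm <- weighted_Random_Sample(items, w, length(items))
print(perm)

# Average position (1 = front) of "a" and "f" across many permutations
pos <- replicate(5000, {
  p <- weighted_Random_Sample(items, w, length(items))
  c(a = match("a", p), f = match("f", p))
})
mean_pos <- rowMeans(pos)
print(mean_pos)
```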

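    Update: I should probably mention that the above function expects the weights to be strictly greater than zero. Otherwise the keys from key <- runif(length(.data)) ^ (1 / .weights) won't be ordered properly: a negative weight, for example, makes the exponent negative, so that item's key is greater than 1 and it always sorts first.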
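One possible refinement, not part of the original answer: because log() is increasing, ranking by log(u) / w gives the same order as ranking by u^(1/w) (up to floating-point rounding), but the log keys do not underflow to 0 when weights are tiny. For w = 1e-5 the plain key is u^100000, which is exactly 0 for almost every u, so ties rather than weights decide the order. The hypothetical weighted_sample_log() below also rejects non-positive weights up front instead of mis-ordering them.

```r
# Hypothetical variant, not the answer's code: rank by log(u) / w instead of u^(1/w)
weighted_sample_log <- function(.data, .weights, .n) {
  stopifnot(length(.data) == length(.weights),
            all(.weights > 0),           # non-positive weights break the ordering
            .n >= 0, .n <= length(.data))
  log_key <- log(runif(length(.data))) / .weights
  head(.data[order(log_key, decreasing = TRUE)], .n)
}

set.seed(7)
# With w = 1e-5 the plain key u^(1/w) = u^100000 underflows to exactly 0 almost always
share_zero <- mean(runif(10000)^(1 / 1e-5) == 0)
print(share_zero)

# Log keys still respect the weights: "y" should come first about
# 3e-5 / (1e-5 + 3e-5) = 75% of the time
w_tiny <- c(1e-5, 3e-5)
first <- replicate(10000, weighted_sample_log(c("x", "y"), w_tiny, 1))
print(mean(first == "y"))

# A negative weight raises an error instead of silently sorting first
res <- try(weighted_sample_log(1:3, c(1, -1, 1), 1), silent = TRUE)
print(inherits(res, "try-error"))
```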


    Just for kicks, I also used the test scenario in the OP to compare both functions.

    library(plyr)     # ldply()
    library(ggplot2)  # plotting
    
    set.seed(1)
    
    # Time weighted_Random_Sample() for n = 2048, 4096, ..., 131072,
    # drawing n items out of 2n with uniform random weights
    times_WRS <- ldply(
      1:7,
      function(i) {
        n <- 1024 * 2^i
        p <- runif(2 * n)
        n_Set <- 1:(2 * n)
        data.frame(
          n = n,
          user = system.time(weighted_Random_Sample(n_Set, p, n), gcFirst = TRUE)['user.self'])
      },
      .progress = 'text'
    )
    
    # Same scenario with base R's sample.int(); returning NULL discards the sample
    sample.int.test <- function(n, p) {
      sample.int(2 * n, n, replace = FALSE, prob = p)
      NULL
    }
    
    times_sample.int <- ldply(
      1:7,
      function(i) {
        n <- 1024 * 2^i
        p <- runif(2 * n)
        data.frame(
          n = n,
          user = system.time(sample.int.test(n, p), gcFirst = TRUE)['user.self'])
      },
      .progress = 'text'
    )
    
    times_WRS$group <- "WRS"
    times_sample.int$group <- "sample.int"
    
    ggplot(rbind(times_WRS, times_sample.int), aes(x = n, y = user / n, col = group)) +
      geom_point() +
      scale_x_log10() +
      ylab('Time per unit (s)')
    

    And here are the times:

    times_WRS
    #        n user
    # 1   2048 0.00
    # 2   4096 0.01
    # 3   8192 0.00
    # 4  16384 0.01
    # 5  32768 0.03
    # 6  65536 0.06
    # 7 131072 0.16
    
    times_sample.int
    #        n  user
    # 1   2048  0.02
    # 2   4096  0.05
    # 3   8192  0.14
    # 4  16384  0.58
    # 5  32768  2.33
    # 6  65536  9.23
    # 7 131072 37.79
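Reading the sample.int column as arithmetic: from n = 16384 on, each doubling of n multiplies the time by roughly four, which is what quadratic running time predicts. The smaller-n rows are left out of this sketch because those times are only a few hundredths of a second. Only the numbers reported above are used:

```r
# sample.int timings (user seconds) copied from the table above for
# n = 16384, 32768, 65536, 131072; n doubles from one row to the next
user <- c(0.58, 2.33, 9.23, 37.79)

# Ratio of consecutive times: about 4 per doubling of n means time ~ n^2
ratio <- user[-1] / user[-length(user)]
print(round(ratio, 2))        # 4.02 3.96 4.09
# Equivalent growth exponent per doubling, log2(ratio), is close to 2
print(round(log2(ratio), 2))  # 2.01 1.99 2.03
```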
    

    [Figure: performance comparison, time per unit (s) against n on a log scale, WRS vs. sample.int]
