Fast random weighted selection across all rows of a stochastic matrix

遥遥无期 2020-12-09 18:13

numpy.random.choice allows for weighted selection from a vector, i.e.

arr = numpy.array([1, 2, 3])
weight         


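Both answers below rely on the same idea: weighted selection is an inverse-CDF lookup. Here is a minimal sketch for a single vector that compares NumPy's `Generator.choice` with a manual lookup. The weights are hypothetical, since the question's own values are cut off. `inverse_cdf_pick` is an illustrative helper, not a NumPy function.

```python
import numpy as np


def inverse_cdf_pick(items, weights, u):
    """Hypothetical helper: map a uniform draw u in [0, 1) to an item under `weights`."""
    cdf = np.cumsum(weights)                     # running totals of the weights
    idx = np.searchsorted(cdf, u, side="right")  # first index whose running total exceeds u
    return items[min(idx, len(items) - 1)]       # guard: cdf[-1] can round slightly below 1


arr = np.array([1, 2, 3])
weights = np.array([0.2, 0.5, 0.3])  # hypothetical weights; they must sum to 1
rng = np.random.default_rng(0)

builtin = rng.choice(arr, p=weights)                   # NumPy's own weighted selection
manual = inverse_cdf_pick(arr, weights, rng.random())  # same distribution, done by hand
```

The answers apply this same lookup to every column of the matrix at once.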
        
2 Answers
  • 2020-12-09 18:39

    Here's a fully vectorized version that's pretty fast:

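    def vectorized(prob_matrix, items):
        # column-wise CDFs: s[:, j] holds the running totals of column j
        s = prob_matrix.cumsum(axis=0)
        # one uniform draw per column
        r = np.random.rand(prob_matrix.shape[1])
        # count CDF entries below each draw: that count is the chosen item's index
        k = (s < r).sum(axis=0)
        # guard: the last CDF entry can round to slightly less than 1
        k = np.minimum(k, prob_matrix.shape[0] - 1)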
        return items[k]
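To make the counting line concrete, here is a tiny worked case with two columns. Fixed draws stand in for `np.random.rand` so that the result is reproducible. The example also checks that counting gives the same index as calling `searchsorted` on each column separately.

```python
import numpy as np

# Two columns (n = 2), each a probability distribution over m = 3 items
prob_matrix = np.array([[0.5, 0.1],
                        [0.3, 0.1],
                        [0.2, 0.8]])
s = prob_matrix.cumsum(axis=0)  # column CDFs: roughly [0.5, 0.8, 1.0] and [0.1, 0.2, 1.0]
r = np.array([0.6, 0.5])        # fixed draws standing in for np.random.rand(2)

# For each column, count how many CDF entries lie strictly below its draw
k = (s < r).sum(axis=0)         # k is [1, 2]

# Reference: the same lookup, one column at a time
k_ref = np.array([np.searchsorted(s[:, j], r[j]) for j in range(s.shape[1])])
```

Column 0's draw of 0.6 exceeds only its first CDF entry, so it selects item index 1. Column 1's draw of 0.5 exceeds two entries, so it selects index 2.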
    

    In theory, searchsorted is the right function for looking up the random value in the cumulatively summed probabilities. With m relatively small, however, k = (s < r).sum(axis=0) ends up being much faster. Counting is O(m) per column, while a binary search with searchsorted is O(log(m)), but that difference only matters for much larger m. Also, cumsum is itself O(m) per column, so both vectorized and @perimosocordiae's improved are O(m) per column, or O(mn) overall. (If your m is, in fact, much larger, you'll have to run some tests to see how large m can be before this method becomes slower.)
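For much larger m, one way to keep a single vectorized call while still using binary search is to shift each column's CDF by its column index. All n CDFs then form one sorted 1-D array, so `np.searchsorted` only needs to be called once. This offset trick is not from either answer. `vectorized_searchsorted` is a hypothetical name, and the sketch uses NumPy's newer `default_rng` Generator API.

```python
import numpy as np


def vectorized_searchsorted(prob_matrix, items, rng=None):
    """Pick one item per column of prob_matrix (shape (m, n), columns sum to 1)."""
    rng = np.random.default_rng() if rng is None else rng
    m, n = prob_matrix.shape
    cols = np.arange(n)
    # Column-wise CDFs, transposed so row j holds column j's CDF: shape (n, m)
    cdf = prob_matrix.cumsum(axis=0).T
    # Shift row j into [j, j + 1] so the flattened array is sorted overall
    flat = (cdf + cols[:, None]).ravel()
    # One uniform draw per column, shifted into the same interval as its CDF
    r = rng.random(n) + cols
    # A single binary search of all n draws over the n * m entries
    pos = np.searchsorted(flat, r, side="right")
    # Turn global positions back into per-column item indices
    k = pos - cols * m
    # Guard against CDFs that end a hair below (or above) 1 due to rounding
    k = np.clip(k, 0, m - 1)
    return items[k]
```

There are two caveats. First, building the CDFs is still O(m) per column, so this approach only pays off when the shifted array can be built once and reused for many draws. Second, adding the column offset costs floating-point precision (roughly 1e-10 when n is around a million). The sketch uses `side="right"`, so an item with zero probability is never chosen. It differs from the counting version `(s < r)` only when a draw lands exactly on a CDF value.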

    Here's the timing I get with m = 10 and n = 10000 (using the functions original and improved from @perimosocordiae's answer):

    In [115]: %timeit original(prob_matrix, items)
    1 loops, best of 3: 270 ms per loop
    
    In [116]: %timeit improved(prob_matrix, items)
    10 loops, best of 3: 24.9 ms per loop
    
    In [117]: %timeit vectorized(prob_matrix, items)
    1000 loops, best of 3: 1 ms per loop
    

    The full script where the functions are defined is:

    import numpy as np
    
    
    def improved(prob_matrix, items):
        n = prob_matrix.shape[1]  # one draw per column
        # transpose here for better data locality later
        cdf = np.cumsum(prob_matrix.T, axis=1)
        # random numbers are expensive, so we'll get all of them at once
        ridx = np.random.random(size=n)
        # the one loop we can't avoid, made as simple as possible
        idx = np.zeros(n, dtype=int)
        for i, r in enumerate(ridx):
            idx[i] = np.searchsorted(cdf[i], r)
        # guard: the last CDF entry can round to slightly less than 1
        idx = np.minimum(idx, prob_matrix.shape[0] - 1)
        # fancy indexing all at once is faster than indexing in a loop
        return items[idx]
    
    
    def original(prob_matrix, items):
        n = prob_matrix.shape[1]  # one draw per column
        choices = np.zeros((n,), dtype=items.dtype)
        # This is slow, because of the loop in Python
        for i in range(n):
            choices[i] = np.random.choice(items, p=prob_matrix[:,i])
        return choices
    
    
    def vectorized(prob_matrix, items):
        s = prob_matrix.cumsum(axis=0)
        r = np.random.rand(prob_matrix.shape[1])
        k = (s < r).sum(axis=0)
        k = np.minimum(k, prob_matrix.shape[0] - 1)  # guard: CDF may end just below 1
        return items[k]
    
    
    m = 10
    n = 10000 # Or some very large number
    
    items = np.arange(m)
    prob_weights = np.random.rand(m, n)
    prob_matrix = prob_weights / prob_weights.sum(axis=0, keepdims=True)
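Timings aside, it is worth confirming that the fast version draws from the right distribution. The sketch below assumes the same random setup as the script. Over many columns, each item's observed share should approach the average of its probabilities across columns. The name `frequency_check` and the column count are illustrative choices.

```python
import numpy as np


def vectorized(prob_matrix, items):
    s = prob_matrix.cumsum(axis=0)
    r = np.random.rand(prob_matrix.shape[1])
    k = (s < r).sum(axis=0)
    k = np.minimum(k, prob_matrix.shape[0] - 1)  # guard: CDF may end just below 1
    return items[k]


def frequency_check(m=10, n=100_000, seed=0):
    """Largest gap between observed item shares and average column probabilities."""
    np.random.seed(seed)  # vectorized() draws from NumPy's legacy global RNG
    weights = np.random.rand(m, n)
    prob_matrix = weights / weights.sum(axis=0, keepdims=True)
    items = np.arange(m)
    draws = vectorized(prob_matrix, items)
    observed = np.bincount(draws, minlength=m) / n
    expected = prob_matrix.mean(axis=1)  # expected share of each item over all columns
    return np.abs(observed - expected).max()


gap = frequency_check()
```

With n = 100,000 columns, the gap should be on the order of a few thousandths.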
    
  • 2020-12-09 18:43

    I don't think it's possible to completely vectorize this, but you can still get a decent speedup by vectorizing as much as you can. Here's what I came up with:

    def improved(prob_matrix, items):
        n = prob_matrix.shape[1]  # one draw per column
        # transpose here for better data locality later
        cdf = np.cumsum(prob_matrix.T, axis=1)
        # random numbers are expensive, so we'll get all of them at once
        ridx = np.random.random(size=n)
        # the one loop we can't avoid, made as simple as possible
        idx = np.zeros(n, dtype=int)
        for i, r in enumerate(ridx):
            idx[i] = np.searchsorted(cdf[i], r)
        # guard: the last CDF entry can round to slightly less than 1
        idx = np.minimum(idx, prob_matrix.shape[0] - 1)
        # fancy indexing all at once is faster than indexing in a loop
        return items[idx]
    

    Testing against the version in the question:

    def original(prob_matrix, items):
        n = prob_matrix.shape[1]  # one draw per column
        choices = np.zeros((n,), dtype=items.dtype)
        # This is slow, because of the loop in Python
        for i in range(n):
            choices[i] = np.random.choice(items, p=prob_matrix[:,i])
        return choices
    

    Here's the speedup (using the setup code given in the question):

    In [45]: %timeit original(prob_matrix, items)
    100 loops, best of 3: 2.86 ms per loop
    
    In [46]: %timeit improved(prob_matrix, items)
    The slowest run took 4.15 times longer than the fastest. This could mean that an intermediate result is being cached
    10000 loops, best of 3: 157 µs per loop
    

    I'm not sure why there's such a large spread in the timings for my version, but even the slowest run (~650 µs, i.e. 4.15 × 157 µs) is still more than 4x faster than the original's 2.86 ms.
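IPython's `%timeit` reports only the best run, which hides the spread that the warning refers to. The sketch below uses the standard library's `timeit.repeat` to look at every run. The matrix sizes are assumptions. `improved` is reproduced from the answer above, with `n` taken from the matrix shape, and `run_spread` is an illustrative helper.

```python
import timeit

import numpy as np


def improved(prob_matrix, items):
    """Per-column searchsorted loop, as in the answer above."""
    n = prob_matrix.shape[1]
    cdf = np.cumsum(prob_matrix.T, axis=1)
    ridx = np.random.random(size=n)
    idx = np.zeros(n, dtype=int)
    for i, r in enumerate(ridx):
        idx[i] = np.searchsorted(cdf[i], r)
    idx = np.minimum(idx, prob_matrix.shape[0] - 1)  # guard: CDF may end just below 1
    return items[idx]


def run_spread(func, *args, repeat=7, number=10):
    """Return (fastest, slowest) seconds per call across `repeat` timing runs."""
    times = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    per_call = [t / number for t in times]
    return min(per_call), max(per_call)


m, n = 10, 1000  # assumed sizes; the question's setup is not shown here
weights = np.random.rand(m, n)
prob_matrix = weights / weights.sum(axis=0, keepdims=True)
items = np.arange(m)

fastest, slowest = run_spread(improved, prob_matrix, items)
print(f"fastest {fastest * 1e6:.0f} µs, slowest {slowest * 1e6:.0f} µs, "
      f"ratio {slowest / fastest:.2f}")
```

A large ratio can reflect system noise (other processes, CPU frequency changes, or a cold first run) rather than anything in the function itself. Comparing several full sets of runs makes that easier to tell apart.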
