Weighted random sample without replacement in python

别那么骄傲 2020-12-07 01:55

I need to obtain a k-sized sample without replacement from a population, where each member of the population has an associated weight (W).

Numpy's rando

3 answers
  •  误落风尘
    2020-12-07 02:17

    Built-in solution

    As suggested by Miriam Farber, you can simply use NumPy's built-in solution:

    np.random.choice(vec, size, replace=False, p=P)
    
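    For reference, here is a minimal sketch of the same call using NumPy's newer Generator API (np.random.default_rng()), which the NumPy documentation recommends for new code. The population and weights are made-up illustration values. Note that p must be a probability vector, so raw weights are normalized first, and replace=False needs at least size items with non-zero probability.

```python
import numpy as np

# Seeded only so that repeated runs give the same result.
rng = np.random.default_rng(seed=42)

# Illustrative population and raw (unnormalized) weights; "d" has weight 0.
population = np.array(["a", "b", "c", "d", "e"])
weights = np.array([1.0, 2.0, 3.0, 0.0, 4.0])

# p must be a probability vector: non-negative entries that sum to 1.
p = weights / weights.sum()

# replace=False needs at least `size` non-zero entries in p (4 here);
# asking for more raises ValueError.
sample = rng.choice(population, size=3, replace=False, p=p)
print(sample)
```
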

    Pure Python equivalent

    What follows is close to what NumPy does internally. NumPy, of course, uses NumPy arrays and numpy.random.choice(); this pure Python version uses lists and random.choices():

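    from random import choices
    
    def weighted_sample_without_replacement(population, weights, k=1):
        weights = list(weights)            # private copy; entries are zeroed as items are taken
        positions = range(len(population))
        # Guard: without this check the loop never finishes (or random.choices()
        # raises ValueError) once every positive-weight item has been taken.
        if k > sum(1 for w in weights if w > 0):
            raise ValueError('k exceeds the number of items with positive weight')
        indices = []
        while True:
            needed = k - len(indices)
            if not needed:
                break
            # Draw a batch *with* replacement and keep only first occurrences.
            for i in choices(positions, weights, k=needed):
                if weights[i]:
                    weights[i] = 0.0       # i can never be drawn again
                    indices.append(i)
        return [population[i] for i in indices]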
    
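    The same batch-and-discard loop can also be written with NumPy arrays: build a CDF with cumsum(), draw the missing number of indices with searchsorted(), keep first occurrences with np.unique(), zero out the chosen weights, and repeat. This is a sketch modeled loosely on the idea behind NumPy's legacy choice(), not a copy of NumPy's source. The function name weighted_sample_np and its rng parameter are hypothetical.

```python
import numpy as np

def weighted_sample_np(population, weights, k, rng=None):
    """Weighted sample of k distinct items via batch-and-discard (sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    population = np.asarray(population)
    p = np.array(weights, dtype=float)   # private copy; zeroed as items are taken
    if k > np.count_nonzero(p > 0):
        raise ValueError("k exceeds the number of items with positive weight")
    found = np.empty(k, dtype=np.intp)
    n_found = 0
    while n_found < k:
        cdf = np.cumsum(p)
        cdf /= cdf[-1]                   # normalize so the CDF ends at 1.0
        # Draw the missing number of indices *with* replacement.
        draws = cdf.searchsorted(rng.random(k - n_found), side="right")
        # Keep only the first occurrence of each index, in draw order.
        _, first = np.unique(draws, return_index=True)
        new = draws[np.sort(first)]
        found[n_found:n_found + new.size] = new
        n_found += new.size
        p[new] = 0.0                     # taken items can never be drawn again
    return population[found]
```

    Discarding repeats does not bias the result: in a stream of draws with replacement, the first draw that is not an already-taken item lands on each remaining item with probability proportional to its weight among the remaining items, which is exactly one-at-a-time sampling without replacement.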

    Related problem: Selection when elements can be repeated

    This is sometimes called an urn problem. For example, given an urn with 10 red balls, 4 blue balls, and 18 green balls, choose nine balls without replacement.

    To do it with NumPy, use random.sample() to draw k distinct positions from range(total), where total is the number of balls in the urn. Then bisect the cumulative counts with np.searchsorted() to turn each position into a population index.

    import numpy as np
    from random import sample
    
    population = np.array(['red', 'blue', 'green'])
    counts = np.array([10, 4, 18])
    k = 9
    
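    cum_counts = np.add.accumulate(counts)   # running totals: [10, 14, 32]
    total = cum_counts[-1]                   # 32 balls in all
    selections = sample(range(total), k=k)   # k distinct ball positions
    # Map each position to its color: 0-9 red, 10-13 blue, 14-31 green.
    indices = np.searchsorted(cum_counts, selections, side='right')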
    result = population[indices]
    
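    A variant that keeps all of the randomness inside NumPy: Generator.choice() with an integer first argument draws k distinct positions from range(total), taking the place of random.sample(). If only the number of balls of each color matters, Generator.multivariate_hypergeometric() (NumPy 1.18 or later) draws those counts directly. The urn values repeat the example above; this is one possible sketch, not the only way to do it.

```python
import numpy as np

rng = np.random.default_rng()

population = np.array(["red", "blue", "green"])
counts = np.array([10, 4, 18])
k = 9

cum_counts = np.cumsum(counts)     # [10, 14, 32]
total = cum_counts[-1]             # 32 balls in the urn

# k distinct ball positions in [0, total), drawn by NumPy itself.
selections = rng.choice(total, size=k, replace=False)

# Positions 0-9 -> red, 10-13 -> blue, 14-31 -> green.
indices = np.searchsorted(cum_counts, selections, side="right")
result = population[indices]

# If order does not matter, draw the per-color counts directly.
per_color = rng.multivariate_hypergeometric(counts, k)
print(result, per_color)
```
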

    To do this without NumPy, the same approach can be implemented with bisect() and accumulate() from the standard library:

    from random import sample
    from bisect import bisect
    from itertools import accumulate
    
    population = ['red', 'blue', 'green']
    weights = [10, 4, 18]
    k = 9
    
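    cum_weights = list(accumulate(weights))   # [10, 14, 32]
    total = cum_weights.pop()                 # 32; cum_weights is now [10, 14]
    selections = sample(range(total), k=k)    # k distinct ball positions
    # bisect() returns 0 for 0-9, 1 for 10-13 and 2 for 14-31.
    indices = [bisect(cum_weights, s) for s in selections]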
    result = [population[i] for i in indices]
    
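    On Python 3.9 or later, random.sample() accepts a keyword-only counts argument, so the urn problem becomes a single call. CPython implements that option with the same accumulate-and-bisect approach shown above. A minimal sketch using the same urn:

```python
from collections import Counter
from random import sample

population = ['red', 'blue', 'green']
counts = [10, 4, 18]

# Equivalent to sampling 9 items from 10*['red'] + 4*['blue'] + 18*['green'].
result = sample(population, k=9, counts=counts)
tally = Counter(result)
print(result, tally)
```
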
