Weighted random sample in Python

余生分开走 2020-11-27 06:19

I'm looking for a reasonable definition of a function weighted_sample that does not return just one random index for a list of given weights (which would be so

6 answers
  •  借酒劲吻你
    2020-11-27 06:56

    From your code: ..

    weight_sample_indexes = lambda weights, k: random.sample([val 
            for val, cnt in enumerate(weights) for i in range(cnt)], k)
    

    .. I assume that weights are positive integers and by "without replacement" you mean without replacement for the unraveled sequence.
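
    To make the "unraveled sequence" concrete, here is a minimal sketch of that reading. The helper name unravel is hypothetical (it does not appear in the question), and the weights are assumed to be non-negative integers:

```python
import random

def unravel(weights):
    """Expand integer weights into a flat list of repeated indexes (hypothetical helper)."""
    return [index for index, count in enumerate(weights) for _ in range(count)]

weights = [1, 10, 2]
expanded = unravel(weights)  # index 0 once, index 1 ten times, index 2 twice

# Sampling without replacement from the expanded list: an index can be
# drawn at most as many times as its weight.
sample = random.sample(expanded, 5)
```

    Under this reading, index 0 can appear at most once in any sample and index 2 at most twice. Materializing the expanded list takes memory proportional to the sum of the weights.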

    Here's a solution based on random.sample and O(log n) __getitem__:

    import bisect
    import random
    from collections import Counter
    from collections.abc import Sequence  # Sequence was removed from collections in Python 3.10
    
    def weighted_sample(population, weights, k):
        return random.sample(WeightedPopulation(population, weights), k)
    
    class WeightedPopulation(Sequence):
        def __init__(self, population, weights):
            assert len(population) == len(weights) > 0
            self.population = population
            self.cumweights = []
            cumsum = 0 # compute cumulative weight
            for w in weights:
                cumsum += w   
                self.cumweights.append(cumsum)  
        def __len__(self):
            return self.cumweights[-1]
        def __getitem__(self, i):
            if not 0 <= i < len(self):
                raise IndexError(i)
            return self.population[bisect.bisect(self.cumweights, i)]
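
    As a rough illustration of how the lookup in __getitem__ behaves, the sketch below rebuilds the virtual sequence for the weights [1, 10, 2] by bisecting the cumulative weights for every valid index. It uses itertools.accumulate in place of the explicit loop above, which is a substitution made here for brevity:

```python
import bisect
from itertools import accumulate

population = "abc"
weights = [1, 10, 2]

# Running totals of the weights: [1, 11, 13]
cumweights = list(accumulate(weights))

# bisect.bisect (bisect_right) returns the position of the first cumulative
# weight strictly greater than i, which is the item that "owns" index i.
virtual = [population[bisect.bisect(cumweights, i)]
           for i in range(cumweights[-1])]

print("".join(virtual))
```

    Each lookup is a binary search over the cumulative weights, so the expanded sequence never has to be stored.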
    

    Example

    total = Counter()
    for _ in range(1000):
        sample = weighted_sample("abc", [1,10,2], 5)
        total.update(sample)
    print(sample)
    print("Frequencies %s" % (dict(Counter(sample)),))
    
    # Check that values are sane
    print("Total " + ', '.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                               for val, count in total.most_common()))
    

    Output

    ['b', 'b', 'b', 'c', 'c']
    Frequencies {'c': 2, 'b': 3}
    Total b: 10, c: 2, a: 1
    
