How to get a list of all indices of repeated elements in a numpy array?

Asked by 你的背包, 2020-12-03 03:37

I'm trying to get the indices of all repeated elements in a numpy array, but the solution I have found so far is REALLY inefficient for a large (>20000 elements) input.
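The question does not show the original code, but a typical inefficient pattern (an assumption on my part) is a pure-Python scan per unique value, which re-reads the whole array once for every distinct value:

```python
import numpy as np

def slow_indices(a):
    # Hypothetical naive approach: one full pass over the array per unique
    # value, i.e. O(n * n_unique) -- this is the kind of solution the
    # answers below improve on.
    return {v: [i for i, x in enumerate(a) if x == v] for v in np.unique(a)}

a = np.array([1, 2, 1, 3, 2, 1])
groups = slow_indices(a)
print(groups[1])  # [0, 2, 5]
```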

6 Answers
  •  Answered by 忘掉有多难
     2020-12-03 04:25

    • The answer is complicated, and highly dependent on the size of the array and the number of unique elements.
    • The following tests arrays with 2M elements and up to 20k unique elements.
    • It also tests arrays of up to 80k elements, with a max of 20k unique elements.
      • For arrays under 40k elements, the tests use up to half as many unique elements as the size of the array (e.g. a 10k-element array would have up to 5k unique elements).

    Arrays with 2M Elements

    • np.where is faster than defaultdict for up to about 200 unique elements, but slower than pandas.core.groupby.GroupBy.indices and np.unique.
    • The solution using pandas is the fastest for large arrays.

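As a minimal sketch of the pandas approach referenced above: `GroupBy.indices` returns a dict mapping each unique value to an array of the row positions where it occurs.

```python
import numpy as np
import pandas as pd

a = np.array([1, 2, 1, 3, 2, 1])

# Group the single column (named 0) by its values; .indices gives
# {value: array of positions} in one call.
idx = pd.DataFrame(a).groupby([0]).indices
print(idx[1])  # positions of the value 1
```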
    Arrays with up to 80k Elements

    • This is more situational, depending on the size of the array and the number of unique elements.
    • defaultdict is a fast option for arrays up to about 2400 elements, especially with a large number of unique elements.
    • For arrays larger than 40k elements, and 20k unique elements, pandas is the fastest option.

    %timeit

    import random
    from collections import defaultdict
    
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    def dd(l):
        # default_dict test
        indices = defaultdict(list)
        for i, v in enumerate(l):
            indices[v].append(i)
        return indices
    
    
    def npw(l):
        # np_where test
        return {v: np.where(l == v)[0] for v in np.unique(l)}
    
    
    def uni(records_array):
        # np_unique test
        idx_sort = np.argsort(records_array)
        sorted_records_array = records_array[idx_sort]
        vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
        res = np.split(idx_sort, idx_start[1:])
        return dict(zip(vals, res))
    
    
    def daf(l):
        # pandas test
        return pd.DataFrame(l).groupby([0]).indices
    
    
    data = defaultdict(list)
    
    for x in range(4, 20000, 100):  # number of unique elements
        # create 2M element list
        random.seed(365)
        a = np.array([random.choice(range(x)) for _ in range(2000000)])
        
        res1 = %timeit -r2 -n1 -q -o dd(a)
        res2 = %timeit -r2 -n1 -q -o npw(a)
        res3 = %timeit -r2 -n1 -q -o uni(a)
        res4 = %timeit -r2 -n1 -q -o daf(a)
        
        data['default_dict'].append(res1.average)
        data['np_where'].append(res2.average)
        data['np_unique'].append(res3.average)
        data['pandas'].append(res4.average)
        data['idx'].append(x)
    
    df = pd.DataFrame(data)
    df.set_index('idx', inplace=True)
    
    df.plot(figsize=(12, 5), xlabel='unique samples', ylabel='average time (s)', title='%timeit test: 2 run 1 loop each')
    plt.legend(bbox_to_anchor=(1.0, 1), loc='upper left')
    plt.show()
    

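A usage sketch of the np.unique-based `uni` function above, on a small array: argsort the array once, then split the sorted index positions at each group boundary returned by `np.unique`.

```python
import numpy as np

def uni(records_array):
    # Sort once; equal values end up adjacent.
    idx_sort = np.argsort(records_array)
    sorted_records_array = records_array[idx_sort]
    # idx_start marks where each run of equal values begins in the sorted order.
    vals, idx_start, count = np.unique(
        sorted_records_array, return_counts=True, return_index=True)
    # Splitting idx_sort at those boundaries yields one index array per value.
    res = np.split(idx_sort, idx_start[1:])
    return dict(zip(vals, res))

a = np.array([10, 20, 10, 30, 20, 10])
groups = uni(a)
print(sorted(groups[10]))  # [0, 2, 5]
```

Note that the positions within each group come out in sorted-array order, not necessarily in original order, since `np.argsort` is not stable by default.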
    Tests with 2M elements

    Tests with up to 80k elements
