How to get a list of all indices of repeated elements in a numpy array?

你的背包 2020-12-03 03:37

I'm trying to get the indices of all repeated elements in a numpy array, but the solution I've found so far is REALLY inefficient for large inputs (>20000 elements).

6 Answers
  •  臣服心动
    2020-12-03 04:20

    A vectorized solution with numpy, built on np.unique().

    import numpy as np
    
    # create a test array
    records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
    
    # indices that sort the array; a stable sort keeps equal elements in their original order
    idx_sort = np.argsort(records_array, kind='stable')
    
    # sorts records array so all unique elements are together 
    sorted_records_array = records_array[idx_sort]
    
    # returns the unique values, the index of the first occurrence of a value, and the count for each element
    vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
    
    # splits the indices into separate arrays
    res = np.split(idx_sort, idx_start[1:])
    
    # filter them with respect to their size, keeping only items occurring more than once
    vals = vals[count > 1]
    # a list comprehension, because filter() returns a lazy iterator in Python 3
    res = [x for x in res if x.size > 1]
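
The steps above can be wrapped in a small helper. This is only a sketch: the function name `duplicate_indices` and the dict return type are assumptions for illustration, not part of the original answer. It asks `argsort` for a stable sort so that the indices inside each group stay in ascending order.

```python
import numpy as np

def duplicate_indices(values):
    """Map each repeated value to the indices where it occurs (hypothetical helper)."""
    values = np.asarray(values)
    # Stable sort: equal elements keep their original relative order,
    # so each group of indices comes out ascending.
    order = np.argsort(values, kind="stable")
    uniq, first, counts = np.unique(values[order], return_index=True, return_counts=True)
    # Cut the sorted index array at the start of every new value.
    groups = np.split(order, first[1:])
    # Keep only the values that occur more than once.
    return {v.item(): g for v, g, c in zip(uniq, groups, counts) if c > 1}

dups = duplicate_indices([1, 2, 3, 1, 1, 3, 4, 3, 2])
for value, idx in dups.items():
    print(value, idx.tolist())
```

Returning a dict keeps each value next to its indices, which avoids having to line up `vals` and `res` by position afterwards.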
    

    The following code was the original answer, which required a bit more memory, using numpy broadcasting and calling unique twice:

    import numpy as np

    records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
    vals, inverse, count = np.unique(records_array, return_inverse=True,
                                     return_counts=True)
    
    # positions (in vals) of the values that occur more than once
    idx_vals_repeated = np.where(count > 1)[0]
    vals_repeated = vals[idx_vals_repeated]
    
    # compare every repeated value against every element via broadcasting
    rows, cols = np.where(inverse == idx_vals_repeated[:, np.newaxis])
    _, inverse_rows = np.unique(rows, return_index=True)
    res = np.split(cols, inverse_rows[1:])
    

    As expected, res = [array([0, 3, 4]), array([1, 8]), array([2, 5, 7])]
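
To see where the extra memory in the broadcasting version goes, it may help to look at the intermediate boolean array it builds. The sketch below reuses the names from the answer; the variable `mask` is introduced here only for illustration. The comparison produces one row per repeated value and one column per input element, so its size grows as (number of repeated values) x (array length).

```python
import numpy as np

records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
vals, inverse, count = np.unique(records_array, return_inverse=True,
                                 return_counts=True)
idx_vals_repeated = np.where(count > 1)[0]

# Broadcasting compares every repeated value against every element,
# giving a boolean matrix with one byte per entry.
mask = inverse == idx_vals_repeated[:, np.newaxis]
print(mask.shape)   # (3, 9): 3 repeated values, 9 elements
print(mask.nbytes)  # 27
```

As a rough illustration, an input of 20000 elements with 5000 distinct repeated values would need a 5000 x 20000 matrix, about 100 MB, whereas the sort-based version only allocates a few arrays of the input's length.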
