How to get a list of all indices of repeated elements in a numpy array?

你的背包 2020-12-03 03:37

I'm trying to get the indices of all repeated elements in a numpy array, but the solution I have found so far is REALLY inefficient for a large (>20000 elements) input.

6 answers
  •  青春惊慌失措
    2020-12-03 04:11

    You could do something along the lines of:

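    1. Add the original index to each element, giving [[1,0],[2,1],[3,2],[1,3],[1,4]...
    2. Sort on the value column, [:,0].
    3. Use np.where(ra[1:,0] != ra[:-1,0]), where ra is the sorted array, to find where the value changes.
    4. Use the list of indexes from step 3 to construct your final list of lists.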
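
    The four steps above could be sketched roughly as follows. The sample values, the names values, breaks, groups and repeated, and the choice to keep only groups with more than one index are assumptions made for illustration, not part of the original suggestion.

```python
import numpy as np

# Hypothetical sample input; any 1-D array should work.
values = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])

# Step 1: pair each value with its original index -> rows of [value, index].
ra = np.column_stack((values, np.arange(len(values))))

# Step 2: sort the rows by value; a stable sort keeps equal values in index order.
ra = ra[np.argsort(ra[:, 0], kind="stable")]

# Step 3: positions where the sorted value changes mark the group boundaries.
breaks = np.where(ra[1:, 0] != ra[:-1, 0])[0] + 1

# Step 4: split the index column at the boundaries (assumption: keep only
# the groups that contain more than one index, i.e. the repeated values).
groups = np.split(ra[:, 1], breaks)
repeated = [g.tolist() for g in groups if len(g) > 1]
print(repeated)
```

    Each inner list holds the original positions of one repeated value, ordered by value.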
    

    EDIT: After my quick reply I was away for a while, and I see I've been voted down, which is fair enough, since numpy.argsort() is a much better way than my suggestion. I did vote up the numpy.unique() answer, as this is an interesting feature. However, if you use timeit you will find that

    idx_start = np.where(sorted_records_array[:-1] != sorted_records_array[1:])[0] + 1
    res = np.split(idx_sort, idx_start)
    

    is marginally faster than

    vals, idx_start = np.unique(sorted_records_array, return_index=True)
    res = np.split(idx_sort, idx_start[1:])
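
    To check the timing claim yourself, something like the sketch below might do. The test data, the function names and the repeat count are assumptions; idx_sort and sorted_records_array are assumed to come from np.argsort, as the variable names suggest. Timings will vary by machine and NumPy version, so no result is shown.

```python
import timeit
import numpy as np

# Hypothetical test data: 20000 integers drawn from 1000 distinct values.
rng = np.random.default_rng(0)
records_array = rng.integers(0, 1000, size=20_000)

# Assumed setup: sort once and keep the permutation of original indices.
idx_sort = np.argsort(records_array, kind="stable")
sorted_records_array = records_array[idx_sort]

def split_with_where():
    # Boundaries are the positions where neighbouring sorted values differ.
    idx_start = np.where(sorted_records_array[:-1] != sorted_records_array[1:])[0] + 1
    return np.split(idx_sort, idx_start)

def split_with_unique():
    # return_index gives the first position of each distinct value.
    vals, idx_start = np.unique(sorted_records_array, return_index=True)
    return np.split(idx_sort, idx_start[1:])

# Both variants should yield identical groups of original indices.
res_where = split_with_where()
res_unique = split_with_unique()
same = len(res_where) == len(res_unique) and all(
    np.array_equal(a, b) for a, b in zip(res_where, res_unique)
)
print("same groups:", same)

for fn in (split_with_where, split_with_unique):
    print(fn.__name__, timeit.timeit(fn, number=20))
```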
    

    Further edit, following the question by @Nicolas:

    I'm not sure you can. It would be possible to get two arrays of indices corresponding to the break points, but you can't use np.split to break different rows of the array into different-sized pieces, so

    a = np.array([[4,27,42,12, 4 .. 240, 12], [3,65,23...] etc])
    idx = np.argsort(a, axis=1)
    # Same result as np.diagonal(a[:, idx]).T, without the large intermediate array
    sorted_a = np.take_along_axis(a, idx, axis=1)
    idx_start = np.where(sorted_a[:,:-1] != sorted_a[:,1:])
    
    # idx_start => (array([0,0,0,..1,1,..]), array([1,4,6,7..99,0,4,5]))
    

    but that might be good enough depending on what you want to do with the information.
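
    If per-row groups are needed after all, one possible follow-up is to loop over the rows and split each one separately. This is only a sketch under assumptions: the small input array and the names rows, cols and groups_per_row are made up for illustration, and np.take_along_axis is used here for the row-wise sort.

```python
import numpy as np

# Hypothetical small 2-D input; each row is processed independently.
a = np.array([[4, 27, 4, 12, 27],
              [3, 3, 65, 23, 3]])

# Sort every row and keep the original column indices.
idx = np.argsort(a, axis=1, kind="stable")
sorted_a = np.take_along_axis(a, idx, axis=1)

# Row and column positions just before each change of value.
rows, cols = np.where(sorted_a[:, :-1] != sorted_a[:, 1:])

# A single np.split call cannot return ragged pieces per row,
# so split row by row using that row's break points.
groups_per_row = [
    [g.tolist() for g in np.split(idx[r], cols[rows == r] + 1)]
    for r in range(a.shape[0])
]
print(groups_per_row)
```

    The Python-level loop runs once per row, so it may be slow for arrays with very many rows.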
