How do I get indices of N maximum values in a NumPy array?

长情又很酷 2020-11-22 04:25

NumPy provides a way to get the index of the maximum value of an array via np.argmax.

I would like something similar, but returning the indexes of the N maximum values.
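For reference, here is a short sketch of the np.argmax behavior the question starts from. It returns a single index: the first occurrence of the maximum. For a multidimensional array without an axis argument, the index refers to the flattened array, and np.unravel_index converts it back to coordinates. The array `a` is borrowed from the answer below; the small matrix `m` is a made-up illustration.

```python
import numpy as np

a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])

# argmax returns one index: the first occurrence of the maximum value.
first_max = np.argmax(a)
print(first_max)  # 0 (a[0] and a[5] are both 9; the first one wins)

# Illustrative 2-D array: without an axis, the index is into the flattened data.
m = np.array([[1, 7, 3],
              [7, 2, 5]])
flat = np.argmax(m)
print(flat)  # 1

# Convert the flat index back to (row, column) coordinates.
row, col = np.unravel_index(flat, m.shape)
print(int(row), int(col))  # 0 1

# With axis=1 you get one argmax per row instead.
print(np.argmax(m, axis=1))  # [1 0]
```

Either way, argmax only ever gives one index per reduction, which is why a different tool is needed for the N largest values.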

18 answers
  •  执念已碎
    2020-11-22 04:47

    NumPy 1.8 and later provide a function called argpartition for this. To get the indices of the four largest elements, do:

    >>> import numpy as np
    >>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
    >>> a
    array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
    >>> ind = np.argpartition(a, -4)[-4:]
    >>> ind
    array([1, 5, 8, 0])
    >>> a[ind]
    array([4, 9, 6, 9])
    

    Unlike argsort, this function runs in linear time in the worst case, but the returned indices are not sorted, as the result of evaluating a[ind] shows. If you need them ordered by value, sort them afterwards (this gives ascending order; reverse with [::-1] to get the largest first):

    >>> ind[np.argsort(a[ind])]
    array([1, 8, 5, 0])
    

    Getting the top-k elements in sorted order this way takes O(n + k log k) time.
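    The two steps above can be combined into a small helper. This is a minimal sketch, assuming a 1-D input; the names top_k_indices and top_k_indices_fullsort are hypothetical (not NumPy functions), and returning indices from largest to smallest value is my own choice. The full-sort baseline with np.argsort costs O(n log n) but is easy to read, so it is included for comparison.

```python
import numpy as np


def top_k_indices(a, k):
    """Hypothetical helper: indices of the k largest values of a 1-D array,
    ordered from largest to smallest value."""
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError("expected a 1-D array")
    n = a.size
    if not 0 < k <= n:
        raise ValueError(f"k must be in 1..{n}, got {k}")
    # Linear-time partial sort: the last k positions hold the indices of
    # the k largest values, in no particular order.
    part = np.argpartition(a, -k)[-k:]
    # Sort only those k candidates (O(k log k)), then reverse for descending.
    return part[np.argsort(a[part])][::-1]


def top_k_indices_fullsort(a, k):
    """Hypothetical baseline: sort everything (O(n log n)), take the last k,
    and reverse so the largest value comes first."""
    a = np.asarray(a)
    return np.argsort(a)[-k:][::-1]


a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
ind = top_k_indices(a, 4)
print(a[ind])  # [9 9 6 4]

# Tied values (the two 9s, the three 4s) may come back at different
# indices from the two methods, so compare the selected values instead.
assert np.array_equal(a[ind], a[top_k_indices_fullsort(a, 4)])
```

    When there are ties, which of the equal elements gets picked is not specified by either approach, so code that depends on specific indices among duplicates should break ties explicitly.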
