How do I get indices of N maximum values in a NumPy array?

长情又很酷 2020-11-22 04:25

NumPy provides a way to get the index of the maximum value of an array: np.argmax.

I would like something similar, but returning the indices of the N maximum values.
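A minimal sketch of one common vectorized approach for a 1-D array (not taken from the answers on this page; `top_n_indices` is a hypothetical name): `np.argpartition` moves the indices of the n largest values to the last n slots without fully sorting the array, and `np.argsort` then orders only those n candidates, largest first.

```python
import numpy as np

def top_n_indices(a, n):
    """Indices of the n largest values of 1-D array `a`, largest first (hypothetical helper)."""
    a = np.asarray(a)
    if not 0 < n <= a.size:
        raise ValueError("n must be between 1 and a.size")
    # argpartition puts the indices of the n largest values in the last n slots, unordered
    part = np.argpartition(a, -n)[-n:]
    # sort only those n candidates by value, then reverse for descending order
    return part[np.argsort(a[part])[::-1]]

print(top_n_indices(np.array([3, 8, 1, 9, 4, 7]), 3))  # [3 1 5]
```

If the order of the n results does not matter, the final `argsort` step can be dropped.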

  •  梦谈多话
    2020-11-22 04:50

    Use:

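    import numpy as np

    def max_indices(arr, k):
        '''
        Returns the positions of the k largest distinct values of arr,
        largest first. Each entry is the tuple returned by np.where, so
        tied elements are grouped together, and fewer than k entries are
        returned when arr has fewer than k distinct values.
        '''
        assert k <= arr.size, 'k should be smaller than or equal to the array size'
        arr_ = arr.astype(float)  # astype returns a copy, so arr is not modified
        max_idxs = []
        for _ in range(k):
            max_element = np.max(arr_)
            if np.isneginf(max_element):
                break  # every element has already been taken
            idx = np.where(arr_ == max_element)
            max_idxs.append(idx)
            arr_[idx] = -np.inf  # mask the found elements for the next pass
        return max_idxs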
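The same grouped-by-value result can also be sketched without the -inf masking loop, assuming the array contains no NaN: `np.unique` already returns the distinct values in ascending order, so reversing it gives them largest first. `max_indices_grouped` is a hypothetical name for this variant, not a NumPy function; unlike the loop, it also reports any -inf entries.

```python
import numpy as np

def max_indices_grouped(arr, k):
    """One np.where tuple per distinct value, for the k largest distinct
    values of arr, largest first (hypothetical variant; assumes no NaN)."""
    arr = np.asarray(arr)
    # np.unique flattens arr and returns its distinct values in ascending order
    largest_first = np.unique(arr)[::-1][:k]
    return [np.where(arr == value) for value in largest_first]

A = np.array([[0.51845014, 0.72528114],
              [0.88421561, 0.18798661],
              [0.89832036, 0.19448609],
              [0.89832036, 0.19448609]])
groups = max_indices_grouped(A, 8)
print(len(groups))   # 6: A has only six distinct values
print(A[groups[0]])  # both copies of the largest value
```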
    

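    It also works with 2D arrays. Because tied elements are grouped, asking for 8 indices of this 4x2 array returns only 6 entries, one per distinct value. For example,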

    In [0]: A = np.array([[ 0.51845014,  0.72528114],
                         [ 0.88421561,  0.18798661],
                         [ 0.89832036,  0.19448609],
                         [ 0.89832036,  0.19448609]])
    In [1]: max_indices(A, 8)
    Out[1]:
        [(array([2, 3], dtype=int64), array([0, 0], dtype=int64)),
         (array([1], dtype=int64), array([0], dtype=int64)),
         (array([0], dtype=int64), array([1], dtype=int64)),
         (array([0], dtype=int64), array([0], dtype=int64)),
         (array([2, 3], dtype=int64), array([1, 1], dtype=int64)),
         (array([1], dtype=int64), array([1], dtype=int64))]
    
    In [2]: A[max_indices(A, 8)[0]]
    Out[2]: array([0.89832036, 0.89832036])
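If you want exactly k individual positions in a 2-D (or N-D) array rather than groups of tied values, one option, sketched here under the assumption that any order among ties is acceptable, is to work on the flattened array and convert the flat indices back with `np.unravel_index`. `top_k_positions` is a hypothetical helper name.

```python
import numpy as np

def top_k_positions(arr, k):
    """Index arrays (one per dimension) of the k largest elements of arr,
    largest first; ties are not grouped (hypothetical helper)."""
    arr = np.asarray(arr)
    flat = arr.ravel()
    if not 0 < k <= flat.size:
        raise ValueError("k must be between 1 and arr.size")
    # flat indices of the k largest values, then ordered by value, descending
    part = np.argpartition(flat, -k)[-k:]
    order = part[np.argsort(flat[part])[::-1]]
    # convert flat positions back to one index array per dimension
    return np.unravel_index(order, arr.shape)

A = np.array([[0.51845014, 0.72528114],
              [0.88421561, 0.18798661],
              [0.89832036, 0.19448609],
              [0.89832036, 0.19448609]])
rows, cols = top_k_positions(A, 3)
# the first two positions are (2, 0) and (3, 0) in either order; the third is (1, 0)
print(rows, cols)
print(A[rows, cols])
```

Since the result is a tuple of index arrays, `A[rows, cols]` returns the corresponding values directly.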
    
