N largest values in each row of ndarray

独厮守ぢ 2021-01-12 20:31

I have an ndarray where each row is a separate histogram. For each row, I wish to find the top N values.

I am aware of a solution for the global top N values (A fas

2 Answers
  •  耶瑟儿~
    2021-01-12 21:19

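    You can use np.partition the same way as in the question you linked. It operates along the last axis by default (axis=-1), so each row is partitioned independently. The top N values end up in the last N columns, though not necessarily in sorted order: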

    In [1]: import numpy as np
    In [2]: a = np.array([[ 5,  4,  3,  2,  1],
       ...:               [10,  9,  8,  7,  6]])
    In [3]: b = np.partition(a, -3)    # top 3 values from each row
    In [4]: b[:,-3:]
    Out[4]: 
    array([[ 3,  4,  5],
           [ 8,  9, 10]])
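If you also need the column positions of the top values, or want them ordered largest first, the same idea extends to np.argpartition. The following is a minimal sketch, assuming a 2-D array and 1 <= n <= number of columns. np.argpartition returns column indices instead of values, np.take_along_axis gathers the matching values, and only the n selected entries per row are then sorted. The helper name top_n_per_row is hypothetical, not a NumPy function.

```python
import numpy as np


def top_n_per_row(a, n):
    """Hypothetical helper: return (values, column indices) of the n largest
    entries in each row of a 2-D array, ordered largest first.

    Assumes a.ndim == 2 and 1 <= n <= a.shape[1].
    """
    # argpartition places the indices of each row's n largest entries
    # in the last n columns, in no particular order.
    idx = np.argpartition(a, -n, axis=1)[:, -n:]
    # Gather the values at those indices, row by row.
    vals = np.take_along_axis(a, idx, axis=1)
    # Sort only the n selected entries per row (ascending), then reverse
    # to get descending order. Reversing avoids negating the values,
    # which would wrap around for unsigned dtypes.
    order = np.argsort(vals, axis=1)[:, ::-1]
    idx = np.take_along_axis(idx, order, axis=1)
    vals = np.take_along_axis(vals, order, axis=1)
    return vals, idx
```

For the array a above, top_n_per_row(a, 3) gives values [[5, 4, 3], [10, 9, 8]] and indices [[0, 1, 2], [0, 1, 2]]. Sorting only the n selected entries keeps the work per row close to linear in the row length when n is small, instead of sorting every full row.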
    
