Get N maximum values and indices along an axis in a NumPy array

谎友^ 2020-12-11 10:44

I think this is an easy question for experienced numpy users.

I have a score matrix. The row index corresponds to samples and the column index corresponds to items. For

3 Answers
  •  甜味超标
    2020-12-11 11:28

    I'd use argsort():

    top2_ind = score_matrix.argsort()[:,::-1][:,:2]
    

    That is, produce an array containing the indices that would sort each row of score_matrix in ascending order:

    array([[1, 2, 0],
           [0, 1, 2],
           [0, 1, 2]])
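    To see this step concretely: the question's score matrix is not shown, so the 3×3 matrix below is hypothetical, with values chosen only so that it produces the same arrays as this answer. Note that argsort() sorts along the last axis by default, which for a 2-D array means each row is sorted independently.

```python
import numpy as np

# Hypothetical score matrix (rows = samples, columns = items); not the
# question's actual data, just values that reproduce the arrays shown here.
score_matrix = np.array([[1.0, 0.2, 0.4],
                         [0.3, 0.6, 0.8],
                         [0.1, 0.3, 0.5]])

# argsort() defaults to axis=-1, so each row is sorted on its own and the
# result holds column indices in ascending order of score.
order = score_matrix.argsort()

# For a 2-D array, axis=-1 and axis=1 refer to the same axis.
assert np.array_equal(order, np.argsort(score_matrix, axis=1))
print(order)
```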
    

    Then reverse the column order with [:, ::-1] so each row runs from highest to lowest score, and keep the first two columns with [:, :2]:

    array([[0, 2],
           [2, 1],
           [2, 1]])
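    A short sketch of the reverse-and-slice step on a hypothetical matrix (the question's data is not shown), with n as an illustrative parameter name. It also shows one alternative: argsorting the negated scores gives descending order directly, and kind="stable" makes tied scores come out lowest column index first, which reversing an ascending sort does not guarantee. Negation assumes a signed or floating-point dtype; with unsigned integers it wraps around.

```python
import numpy as np

# Hypothetical matrix (illustrative values, not the question's data).
score_matrix = np.array([[1.0, 0.2, 0.4],
                         [0.3, 0.6, 0.8],
                         [0.1, 0.3, 0.5]])
n = 2  # how many top items to keep per row

# The answer's approach: ascending argsort, reverse columns, keep the first n.
top_ind = score_matrix.argsort()[:, ::-1][:, :n]

# Alternative: sort the negated scores, which is already descending order.
# kind="stable" keeps tied scores in their original column order.
top_ind_neg = np.argsort(-score_matrix, axis=1, kind="stable")[:, :n]

print(top_ind)
print(np.array_equal(top_ind, top_ind_neg))  # no ties here, so both agree
```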
    

    Then do the same with np.sort() to get the values themselves:

    top2_score = np.sort(score_matrix)[:,::-1][:,:2]
    

    Following the same mechanics as above, this gives you:

    array([[ 1. ,  0.4],
           [ 0.8,  0.6],
           [ 0.5,  0.3]])
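    Sorting every row costs O(m log m) per row for m items. When rows are long and N is small, one alternative (a different technique from this answer) is np.argpartition, which only guarantees that the N largest entries land in the last N positions, followed by sorting just those N. The sketch below wraps that in a hypothetical helper, top_n (not a NumPy function), and uses np.take_along_axis to pull out the values matching the indices, so the matrix is not sorted a second time. The example matrix is hypothetical.

```python
import numpy as np

def top_n(a, n, axis=-1):
    """Hypothetical helper: return (values, indices) of the n largest
    entries of `a` along `axis`, ordered from largest to smallest.
    Assumes 1 <= n <= a.shape[axis]."""
    a = np.asarray(a)
    size = a.shape[axis]
    # Partition so the n largest entries occupy the last n slots along
    # `axis` (in arbitrary order) without doing a full sort.
    part = np.argpartition(a, size - n, axis=axis)
    cand = np.take(part, np.arange(size - n, size), axis=axis)
    # Look up the candidates' scores and sort only those n, descending.
    cand_vals = np.take_along_axis(a, cand, axis=axis)
    order = np.flip(np.argsort(cand_vals, axis=axis), axis=axis)
    idx = np.take_along_axis(cand, order, axis=axis)
    # Gather the values at the final indices so the two always line up.
    vals = np.take_along_axis(a, idx, axis=axis)
    return vals, idx

# Hypothetical example data (rows = samples, columns = items).
score_matrix = np.array([[1.0, 0.2, 0.4],
                         [0.3, 0.6, 0.8],
                         [0.1, 0.3, 0.5]])

top2_score, top2_ind = top_n(score_matrix, 2, axis=1)
print(top2_ind)
print(top2_score)
```

    Among tied scores, argpartition makes no promise about which entries are selected, so results can differ from the argsort approach when ties exist.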
    
