Get N maximum values and indices along an axis in a NumPy array

谎友^ 2020-12-11 10:44

I think this is an easy question for experienced numpy users.

I have a score matrix. The row index corresponds to samples and the column index corresponds to items. For

3 answers
  •  被撕碎了的回忆
    2020-12-11 11:16

    Here's an approach using np.argpartition -

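    import numpy as np
    # kth = -M..-1 puts the M largest of each row, ascending, in the last M slots; the slice reverses them
    idx = np.argpartition(a,range(-M,0))[:,:-M-1:-1] # topM_ind
    out = a[np.arange(a.shape[0])[:,None],idx]       # topM_score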
    

    Sample run -

    In [343]: a
    Out[343]: 
    array([[ 1. ,  0.3,  0.4],
           [ 0.2,  0.6,  0.8],
           [ 0.1,  0.3,  0.5]])
    
    In [344]: M = 2
    
    In [345]: idx = np.argpartition(a,range(-M,0))[:,:-M-1:-1]
    
    In [346]: idx
    Out[346]: 
    array([[0, 2],
           [2, 1],
           [2, 1]])
    
    In [347]: a[np.arange(a.shape[0])[:,None],idx]
    Out[347]: 
    array([[ 1. ,  0.4],
           [ 0.8,  0.6],
           [ 0.5,  0.3]])
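
    The same idea can be wrapped in a small function. This is only a sketch: the name top_m is made up for illustration, it assumes a 2-D array with 1 <= M <= number of columns, and it uses np.take_along_axis (NumPy 1.15+) instead of the np.arange indexing. It partitions on the last M positions of each row (kth = -M, ..., -1), so those positions hold the M largest entries in ascending order whatever the number of columns.

```python
import numpy as np

def top_m(a, M):
    """Return (indices, values) of the M largest entries in each row, largest first.

    Hypothetical helper; assumes a is 2-D and 1 <= M <= a.shape[1].
    """
    a = np.asarray(a)
    # Ask argpartition to place the last M positions of each row in sorted order.
    kth = np.arange(-M, 0)
    # Take those last M positions and reverse them, so the largest comes first.
    idx = np.argpartition(a, kth, axis=1)[:, :-M-1:-1]
    # Gather the scores that belong to the selected column indices.
    vals = np.take_along_axis(a, idx, axis=1)
    return idx, vals

a = np.array([[1.0, 0.3, 0.4],
              [0.2, 0.6, 0.8],
              [0.1, 0.3, 0.5]])
idx, vals = top_m(a, 2)
print(idx)
print(vals)
```

    On the sample matrix this should reproduce the idx and score arrays shown in the run above.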
    

    Alternatively, np.argsort gives shorter code for idx, though it is likely slower because it sorts every row in full -

    idx = a.argsort(1)[:,:-M-1:-1]
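
    A quick way to convince yourself that the argsort one-liner picks the same columns as a partition-based version is to compare them on a matrix wider than the sample. This is a sketch with made-up sizes and an arbitrary seed; it assumes the scores have no ties, since the order among tied entries is not guaranteed to match.

```python
import numpy as np

rng = np.random.default_rng(0)   # arbitrary seed, for reproducibility only
a = rng.random((4, 6))           # random scores; ties are very unlikely
M = 3

# Full sort of each row, then the last M columns in reverse (largest first).
idx_sort = a.argsort(1)[:, :-M-1:-1]

# Partial sort: only the last M positions of each row are put in order.
idx_part = np.argpartition(a, np.arange(-M, 0), axis=1)[:, :-M-1:-1]

same = np.array_equal(idx_sort, idx_part)
print(same)
```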
    

    Here's a post containing some runtime tests that compare np.argsort and np.argpartition on a similar problem.
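
    If you want to time the two on your own data, something like the following could serve as a starting point. The matrix size, M, and the repeat count are arbitrary assumptions, and the helper names are made up; the numbers will depend on your machine and NumPy version, so no result is claimed here.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((1000, 500))   # hypothetical size: 1000 samples, 500 items
M = 10

def with_argsort():
    # Sorts every row completely, then keeps the last M columns reversed.
    return a.argsort(1)[:, :-M-1:-1]

def with_argpartition():
    # Orders only the last M positions of every row.
    return np.argpartition(a, np.arange(-M, 0), axis=1)[:, :-M-1:-1]

t_sort = timeit.timeit(with_argsort, number=5)
t_part = timeit.timeit(with_argpartition, number=5)
print(f"argsort:      {t_sort:.4f} s for 5 runs")
print(f"argpartition: {t_part:.4f} s for 5 runs")
```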
