numpy matrix, setting 0 to values by sorting each row

青春惊慌失措 2020-12-11 23:29

I have a matrix with many rows and 8 columns. Each cell represents the probability that the current row belongs to one of the 8 classes. I would like to keep only the 2 highest…

2 answers
  •  轮回少年
    2020-12-12 00:16

    Here's one vectorized approach with np.argpartition -

    m, n = a.shape
    # Zero out all but the 2 largest values in each row (modifies a in place)
    a[np.arange(m)[:, None], np.argpartition(a, n-2, axis=1)[:, :-2]] = 0
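
To see why the indexing expression works, it can be unpacked into named steps. This is only a sketch on a small made-up array; the names `part`, `drop_cols`, and `rows` are illustrative and do not appear in the answer.

```python
import numpy as np

# Small made-up example: 2 rows, 4 columns, no ties within a row.
a = np.array([[0.1, 0.7, 0.2, 0.9],
              [0.8, 0.3, 0.6, 0.4]])
m, n = a.shape

# argpartition puts the index of the value that would sit at sorted
# position n-2 into column n-2; indices of smaller values come before it
# and indices of larger values after it (order inside each group is arbitrary).
part = np.argpartition(a, n - 2, axis=1)

# Everything except the last two columns of `part` therefore indexes
# the n-2 smallest values of each row.
drop_cols = part[:, :-2]        # shape (m, n-2)

# A column vector of row numbers broadcasts against drop_cols, so each
# column index is paired with its own row.
rows = np.arange(m)[:, None]    # shape (m, 1)

a[rows, drop_cols] = 0
print(a)
```

Because only a partition is needed, not a full sort, this avoids the cost of `np.argsort` on each row.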
    

    Sample run -

    In [570]: a
    Out[570]: 
    array([[ 0.94791114,  0.48438182,  0.54574317,  0.45481231,  0.94013836],
           [ 0.03861196,  0.99047316,  0.7897759 ,  0.38863967,  0.93659426],
           [ 0.49436676,  0.93762758,  0.33694977,  0.45701655,  0.73078113],
           [ 0.21240062,  0.85141765,  0.00815352,  0.52517721,  0.49752736]])
    
    In [571]: m,n = a.shape
         ...: a[np.arange(m)[:,None],np.argpartition(a,n-2,axis=1)[:,:-2]] = 0
         ...: 
    
    In [572]: a
    Out[572]: 
    array([[ 0.94791114,  0.        ,  0.        ,  0.        ,  0.94013836],
           [ 0.        ,  0.99047316,  0.        ,  0.        ,  0.93659426],
           [ 0.        ,  0.93762758,  0.        ,  0.        ,  0.73078113],
           [ 0.        ,  0.85141765,  0.        ,  0.52517721,  0.        ]])
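
The same idea can be wrapped in a helper that keeps the top `k` values per row and leaves the input untouched. This is a minimal sketch, assuming a 2-D array and `1 <= k <= n`; the function name `keep_top_k` and the sample array `probs` are hypothetical. If several values tie at the cutoff, which of them survives is not specified, just as with the one-liner above.

```python
import numpy as np

def keep_top_k(a, k=2):
    """Return a copy of 2-D array `a` with all but the k largest values per row set to 0."""
    a = np.asarray(a)
    n = a.shape[1]
    if not 1 <= k <= n:
        raise ValueError("k must be between 1 and the number of columns")
    out = a.copy()
    if k == n:
        # Nothing to drop: every value is among the k largest.
        return out
    # Column indices of the n-k smallest entries in each row.
    drop = np.argpartition(a, n - k, axis=1)[:, :n - k]
    # Write zeros at those positions, row by row, in the copy only.
    np.put_along_axis(out, drop, 0, axis=1)
    return out

# Made-up probabilities for illustration.
probs = np.array([[0.05, 0.40, 0.10, 0.45],
                  [0.60, 0.25, 0.05, 0.10]])
print(keep_top_k(probs, k=2))
```

Returning a copy is a design choice for this sketch; the original snippet modifies `a` in place, which saves memory when the unmasked values are no longer needed.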
    
