numpy matrix, setting 0 to values by sorting each row

青春惊慌失措 2020-12-11 23:29

I have a matrix with many rows and 8 columns. Each cell holds the probability that the row belongs to one of the 8 classes. I would like to keep only the 2 highe

2 Answers
  • 2020-12-12 00:04

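    This should work, but note that it modifies a in place. Is that what you want? And is it essential to avoid loops?

    import numpy as np

    # Sort each row in ascending order; named a_sorted to avoid shadowing the built-in sorted
    a_sorted = np.sort(a, axis=1)

    for idx, row in enumerate(a):
        # row is a view of a, so this zeroes everything below the row's second-largest value
        row[row < a_sorted[idx, -2]] = 0

    Or, without the loop:

    # a_sorted[:, None, -2] has shape (m, 1), so it broadcasts across each row
    a[a < a_sorted[:, None, -2]] = 0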
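
If altering a is not acceptable, the same threshold idea can be written so that it returns a new array instead. This is only a minimal sketch: the function name keep_top2_copy and the sample values are made up for illustration, and it assumes a 2-D float array with at least two columns.

```python
import numpy as np

def keep_top2_copy(a):
    """Return a copy of a with everything below each row's second-largest value set to 0."""
    a = np.asarray(a)
    # Second-largest value per row, kept as an (m, 1) column so it broadcasts row-wise
    threshold = np.sort(a, axis=1)[:, -2:-1]
    # np.where builds a new array, so the input is left untouched
    return np.where(a < threshold, 0, a)

# Illustrative data only (not from the question)
a = np.array([[0.1, 0.7, 0.2, 0.9],
              [0.5, 0.3, 0.8, 0.4]])
result = keep_top2_copy(a)
```

Because the comparison is against a threshold, any entries tied with the second-largest value are all kept, so a row can end up with more than two non-zero values.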
    
  • 2020-12-12 00:16

    Here's one vectorized approach with np.argpartition -

    m,n = a.shape
    a[np.arange(m)[:,None],np.argpartition(a,n-2,axis=1)[:,:-2]] = 0
    

    Sample run -

    In [570]: a
    Out[570]: 
    array([[ 0.94791114,  0.48438182,  0.54574317,  0.45481231,  0.94013836],
           [ 0.03861196,  0.99047316,  0.7897759 ,  0.38863967,  0.93659426],
           [ 0.49436676,  0.93762758,  0.33694977,  0.45701655,  0.73078113],
           [ 0.21240062,  0.85141765,  0.00815352,  0.52517721,  0.49752736]])
    
    In [571]: m,n = a.shape
         ...: a[np.arange(m)[:,None],np.argpartition(a,n-2,axis=1)[:,:-2]] = 0
         ...: 
    
    In [572]: a
    Out[572]: 
    array([[ 0.94791114,  0.        ,  0.        ,  0.        ,  0.94013836],
           [ 0.        ,  0.99047316,  0.        ,  0.        ,  0.93659426],
           [ 0.        ,  0.93762758,  0.        ,  0.        ,  0.73078113],
           [ 0.        ,  0.85141765,  0.        ,  0.52517721,  0.        ]])
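
The one-liner above could be wrapped into a small helper that works for any k and leaves the input alone. Treat this as a sketch under assumptions rather than a definitive version: keep_top_k and the sample probabilities are hypothetical, and it assumes a 2-D array with 1 <= k <= number of columns.

```python
import numpy as np

def keep_top_k(a, k=2):
    """Return a copy of a with all but the k largest entries of each row set to 0."""
    a = np.asarray(a)
    m, n = a.shape
    out = a.copy()
    # After partitioning at position n - k, the first n - k indices of each row
    # point at entries that are not among the k largest (in no particular order)
    drop = np.argpartition(a, n - k, axis=1)[:, :n - k]
    # Pair every row index with that row's own "drop" column indices
    out[np.arange(m)[:, None], drop] = 0
    return out

# Illustrative data only (not from the sample run)
probs = np.array([[0.05, 0.40, 0.10, 0.30, 0.15],
                  [0.25, 0.05, 0.35, 0.20, 0.15]])
top2 = keep_top_k(probs, k=2)
```

Unlike a threshold comparison, this always keeps exactly k entries per row, so when values are tied at the cut-off, which of the tied entries survive is arbitrary.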
    