Top N values in 2d array with duplicates to mask

折月煮酒 提交于 2020-05-31 04:08:19

问题


I have 2d numpy array:

arr = np.array([[0.1, 0.1, 0.3, 0.4, 0.5], 
                [0.06, 0.1, 0.1, 0.1, 0.01], 
                [0.24, 0.24, 0.24, 0.24, 0.24], 
                [0.2, 0.25, 0.3, 0.12, 0.02]])
print (arr)
[[0.1  0.1  0.3  0.4  0.5 ]
 [0.06 0.1  0.1  0.1  0.01]
 [0.24 0.24 0.24 0.24 0.24]
 [0.2  0.25 0.3  0.12 0.02]]

I want filter top N values, so I use argsort:

N = 2
arr1 = np.argsort(-arr, kind='mergesort') < N
print (arr1)
[[False False False  True  True]
 [ True False False  True False] <- first top 2 are duplicates
 [ True  True False False False]
 [False  True  True False False]]

It working nice, at least not top duplicates, like for row 2.

Expected output:

print (arr1)
[[False False False  True  True]
 [False  True  True False False]
 [ True  True False False False]
 [False  True  True False False]]

Is possible some faster way for handle it?


回答1:


Slice to get those top N indices and use those to create the final mask -

idx = np.argsort(-arr, kind='mergesort')[:,:N]
mask = np.zeros(arr.shape, dtype=bool)
np.put_along_axis(mask, idx, True, axis=-1)


来源:https://stackoverflow.com/questions/61517878/top-n-values-in-2d-array-with-duplicates-to-mask

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!