I need the N minimum (index) values in a numpy array

Submitted by 大憨熊 on 2019-12-03 10:51:35

If you call

arr.argsort()[:3]

It will give you the indices of the 3 smallest elements.

array([0, 2, 1], dtype=int64)

So, for n, you should call

arr.argsort()[:n]
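Here is a self-contained version of the argsort approach. The array `[1, 3, 2, 4, 5]` is an assumption: it was chosen because it reproduces the output shown above, and the question's actual data may differ. The helper name `n_smallest_indices` is hypothetical.

```python
import numpy as np

def n_smallest_indices(arr, n):
    """Return the indices of the n smallest values, ordered by value.

    argsort does a full O(N log N) sort; slicing keeps only the first n indices.
    """
    return np.argsort(arr)[:n]

# Assumed example array (not taken from the question).
arr = np.array([1, 3, 2, 4, 5])
print(n_smallest_indices(arr, 3))  # [0 2 1]
```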
Alex

Since this question was posted, NumPy has added a faster way of selecting the smallest elements of an array: np.argpartition. It was first included in NumPy 1.8.

Using snarly's answer as inspiration, we can quickly find the k=3 smallest elements:

In [1]: import numpy as np

In [2]: arr = np.array([1, 3, 2, 4, 5])

In [3]: k = 3

In [4]: ind = np.argpartition(arr, k)[:k]

In [5]: ind
Out[5]: array([0, 2, 1])

In [6]: arr[ind]
Out[6]: array([1, 2, 3])

This will run in O(n) time because it does not need to do a full sort. If you need your answers sorted, you can sort the output. (In this case the output happened to come out in sorted order, but argpartition does not guarantee that.)

In [7]: np.sort(arr[ind])
Out[7]: array([1, 2, 3])

This runs in O(n + k log k) time because the sort only operates on the k selected elements.
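Sorting `arr[ind]` gives the values in order but loses the indices. A minimal sketch for getting the indices themselves ordered by value is shown below; the function name `k_smallest_indices_sorted` is hypothetical. One detail worth hedging: `np.argpartition(arr, k)` requires `k < len(arr)`, so this sketch partitions at `k - 1` instead, which puts the k-th smallest value at position `k - 1` with all smaller values before it.

```python
import numpy as np

def k_smallest_indices_sorted(arr, k):
    """Indices of the k smallest values of a 1-D array, ordered by value.

    O(n) selection with np.argpartition, then an O(k log k) sort of the
    k candidates, for O(n + k log k) overall.
    """
    arr = np.asarray(arr)
    if k <= 0:
        return np.array([], dtype=np.intp)
    if k >= arr.size:
        # Asking for everything: a plain argsort is simplest.
        return np.argsort(arr)
    # kth=k-1: positions 0..k-1 now hold the indices of the k smallest values.
    candidates = np.argpartition(arr, k - 1)[:k]
    # Reorder only those k candidates by their values.
    return candidates[np.argsort(arr[candidates])]

print(k_smallest_indices_sorted([1, 3, 2, 4, 5], 3))  # [0 2 1]
```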

I don't guarantee that this will be faster, but an asymptotically better algorithm would rely on heapq:

import heapq
# Iterate over the indices, ranking each index i by its value arr[i].
indices = heapq.nsmallest(10, range(len(arr)), key=arr.__getitem__)

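This should take approximately O(N log n) operations for the n = 10 smallest values (effectively O(N) when n is small and fixed), whereas argsort takes O(N log N) operations. However, argsort runs entirely in highly optimized C, while heapq calls a Python-level key function for every element, so argsort might still perform better. To know for sure, you'd need to run some tests on your actual data.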
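A rough benchmarking sketch for comparing the approaches from this thread (argsort, argpartition, heapq) with `timeit`. The random test data, array size, and `n` are assumptions; substitute your real array. The heapq variant here iterates over `range(len(arr))` so that the key function receives indices rather than values.

```python
import heapq
import timeit

import numpy as np

# Hypothetical test data: substitute your real array and n.
rng = np.random.default_rng(0)
arr = rng.random(100_000)
n = 10

def with_argsort():
    # Full O(N log N) sort, then keep the first n indices.
    return arr.argsort()[:n]

def with_argpartition():
    # O(N) selection (kth=n is valid because n < arr.size), then sort the n candidates.
    ind = np.argpartition(arr, n)[:n]
    return ind[np.argsort(arr[ind])]

def with_heapq():
    # Iterates over indices in Python, ranking each by arr[i].
    return heapq.nsmallest(n, range(len(arr)), key=arr.__getitem__)

# Sanity check: all three agree (ties are vanishingly unlikely for random floats).
assert with_argsort().tolist() == with_argpartition().tolist() == with_heapq()

repeats = 5
for fn in (with_argsort, with_argpartition, with_heapq):
    seconds = timeit.timeit(fn, number=repeats)
    print(f"{fn.__name__}: {seconds / repeats * 1e3:.2f} ms per call")
```

Which approach wins depends on N, n, and your NumPy version, so treat the measured timings, not the big-O estimates, as the answer.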

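Just don't reverse the sort results: argsort returns indices ordered from the smallest value to the largest, so the first n indices already belong to the n smallest values.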

In [164]: a = numpy.random.random(20)

In [165]: a
Out[165]: 
array([ 0.63261763,  0.01718228,  0.42679479,  0.04449562,  0.19160089,
        0.29653725,  0.93946388,  0.39915215,  0.56751034,  0.33210873,
        0.17521395,  0.49573607,  0.84587652,  0.73638224,  0.36303797,
        0.2150837 ,  0.51665416,  0.47111993,  0.79984964,  0.89231776])

Sorted:

In [166]: a.argsort()
Out[166]: 
array([ 1,  3, 10,  4, 15,  5,  9, 14,  7,  2, 17, 11, 16,  8,  0, 13, 18,
       12, 19,  6])

First ten:

In [168]: a.argsort()[:10]
Out[168]: array([ 1,  3, 10,  4, 15,  5,  9, 14,  7,  2])
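To make the point about reversing concrete, here is a small sketch on a hypothetical five-element array. It assumes (since the question text isn't reproduced here) that the original code reversed the argsort output, for example with `[::-1]`, which yields the largest values first instead of the smallest.

```python
import numpy as np

# Hypothetical data; any 1-D array behaves the same way.
a = np.array([0.6, 0.1, 0.4, 0.05, 0.9])

order = a.argsort()         # indices from smallest to largest value
smallest3 = order[:3]       # indices of the 3 smallest values
largest3 = order[::-1][:3]  # reversing first gives the 3 largest instead

print(smallest3)  # [3 1 2]
print(largest3)   # [4 0 2]
```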