Cannot understand numpy argpartition output

孤独总比滥情好 2020-12-05 14:05

I am trying to use argpartition from numpy, but something seems to be going wrong and I cannot figure out what. Here is what's happening:

These are fir

3 Answers
  •  攒了一身酷
    2020-12-05 14:29

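    With a scalar kth, argpartition only guarantees that the element at position kth lands in its sorted place, with smaller elements before it and larger ones after it, in no particular order. To keep several positions sorted, pass the list of indices that should be in sorted order instead of a scalar. Thus, to keep the first 5 elements in sorted order, instead of np.argpartition(a,5)[:5], simply do: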

    np.argpartition(a,range(5))[:5]
    

    Here's a sample run to make things clear:

    In [84]: a = np.random.rand(10)
    
    In [85]: a
    Out[85]: 
    array([ 0.85017222,  0.19406266,  0.7879974 ,  0.40444978,  0.46057793,
            0.51428578,  0.03419694,  0.47708   ,  0.73924536,  0.14437159])
    
    In [86]: a[np.argpartition(a,5)[:5]]
    Out[86]: array([ 0.19406266,  0.14437159,  0.03419694,  0.40444978,  0.46057793])
    
    In [87]: a[np.argpartition(a,range(5))[:5]]
    Out[87]: array([ 0.03419694,  0.14437159,  0.19406266,  0.40444978,  0.46057793])
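
    The difference comes down to what each form of kth promises. Below is a minimal sketch, assuming a 1-D array of distinct floats, that checks both guarantees. The helper is_partitioned_at is hypothetical and is written only for this illustration:

```python
import numpy as np


def is_partitioned_at(values, k):
    """Hypothetical helper: True if values[:k] <= values[k] <= values[k+1:]."""
    pivot = values[k]
    return bool(np.all(values[:k] <= pivot) and np.all(values[k + 1:] >= pivot))


rng = np.random.default_rng(0)
a = rng.random(10)
sorted_a = np.sort(a)

# Scalar kth=5: only position 5 is guaranteed to hold its sorted value.
# The five entries before it are all <= that value, but in arbitrary order.
scalar_part = a[np.argpartition(a, 5)]
print("scalar kth, position 5 correct:", scalar_part[5] == sorted_a[5])
print("scalar kth, partitioned at 5:", is_partitioned_at(scalar_part, 5))

# Sequence kth=range(5): each of positions 0..4 holds its sorted value,
# so the first five entries come out in ascending order.
seq_part = a[np.argpartition(a, range(5))]
print("sequence kth, first five sorted:", np.array_equal(seq_part[:5], sorted_a[:5]))
```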
    

    Note that argpartition only pays off performance-wise when we want sorted indices for a small subset of the elements, say k elements where k is a small fraction of the total number of elements.
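
    As a sketch of that use case, assuming a 1-D NumPy array and 0 < k <= a.size, the helpers below wrap the range(k) trick. The names smallest_k_sorted and largest_k_sorted are hypothetical. The largest-k variant negates the input, which assumes a signed integer or float dtype (negating unsigned integers wraps around):

```python
import numpy as np


def smallest_k_sorted(a, k):
    """Indices of the k smallest entries of a, ordered by ascending value."""
    a = np.asarray(a)
    # range(k) asks argpartition to place each of positions 0..k-1 in its
    # final sorted slot; the rest of the array is only partitioned.
    return np.argpartition(a, range(k))[:k]


def largest_k_sorted(a, k):
    """Indices of the k largest entries of a, ordered by descending value."""
    # Assumes a signed or floating-point dtype so that negation reverses order.
    return smallest_k_sorted(-np.asarray(a), k)


rng = np.random.default_rng(1)
a = rng.random(1_000) * 100
print(a[smallest_k_sorted(a, 5)])
print(a[largest_k_sorted(a, 5)])
```

    With distinct values, the results match np.argsort(a)[:k] and np.argsort(a)[::-1][:k] respectively, without paying for a full sort.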

    To make this point clear, let's use a bigger dataset and try to get sorted indices for all elements:

    In [51]: a = np.random.rand(10000)*100
    
    In [52]: %timeit np.argpartition(a,range(a.size-1))[:5]
    10 loops, best of 3: 105 ms per loop
    
    In [53]: %timeit a.argsort()
    1000 loops, best of 3: 893 µs per loop
    

    Thus, to sort all elements, np.argpartition isn't the way to go.
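
    The same full-sort comparison can be reproduced outside IPython with the standard timeit module. This is a sketch: the array size, repeat count, and seed are arbitrary choices, and timings depend on the machine and NumPy version, so none are quoted here. It also checks that requesting every position from argpartition yields the same order as argsort:

```python
import timeit

import numpy as np

rng = np.random.default_rng(2)
a = rng.random(10_000) * 100

# Request a sorted position for every index but the last: the last slot
# is then forced to hold the maximum, so the result is a full sort.
full_by_partition = np.argpartition(a, range(a.size - 1))
full_by_argsort = np.argsort(a)
print("same order:", np.array_equal(full_by_partition, full_by_argsort))

t_part = timeit.timeit(lambda: np.argpartition(a, range(a.size - 1)), number=3) / 3
t_sort = timeit.timeit(lambda: np.argsort(a), number=3) / 3
print(f"argpartition, all kth: {t_part * 1e3:.2f} ms per call")
print(f"argsort:               {t_sort * 1e3:.2f} ms per call")
```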

    Now, let's say I want sorted indices for only the 5 smallest elements of that big dataset, kept in sorted order:

    In [68]: a = np.random.rand(10000)*100
    
    In [69]: np.argpartition(a,range(5))[:5]
    Out[69]: array([1647,  942, 2167, 1371, 2571])
    
    In [70]: a.argsort()[:5]
    Out[70]: array([1647,  942, 2167, 1371, 2571])
    
    In [71]: %timeit np.argpartition(a,range(5))[:5]
    10000 loops, best of 3: 112 µs per loop
    
    In [72]: %timeit a.argsort()[:5]
    1000 loops, best of 3: 888 µs per loop
    

    Very useful here: roughly 8x faster than a full argsort (112 µs vs 888 µs).
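
    A commonly used alternative, which is not part of this answer, is to partition once with a scalar kth and then argsort only the k selected values. Below is a sketch assuming a 1-D array of distinct values (with ties, the order among equal values may differ from argsort's) and an arbitrary k = 5. The name top_k_two_step is hypothetical:

```python
import timeit

import numpy as np


def top_k_two_step(a, k):
    """Hypothetical helper: indices of the k smallest entries, ascending by value."""
    # Step 1: a single scalar kth of k - 1 puts the k smallest values in the
    # first k slots, in arbitrary order (k - 1 also works when k == a.size).
    idx = np.argpartition(a, k - 1)[:k]
    # Step 2: sort just those k values; argsort returns positions within idx.
    return idx[np.argsort(a[idx])]


rng = np.random.default_rng(3)
a = rng.random(10_000) * 100
k = 5

reference = np.argsort(a)[:k]
print(np.array_equal(top_k_two_step(a, k), reference))
print(np.array_equal(np.argpartition(a, range(k))[:k], reference))

for label, fn in [
    ("argpartition(a, range(k))", lambda: np.argpartition(a, range(k))[:k]),
    ("two-step", lambda: top_k_two_step(a, k)),
    ("argsort", lambda: np.argsort(a)[:k]),
]:
    seconds = timeit.timeit(fn, number=200) / 200
    print(f"{label:26s} {seconds * 1e6:8.1f} µs per call")
```

    The two-step form always does one partition plus a sort of only k items, whereas range(k) asks for k separate sorted positions. Which is faster for a given k is worth measuring rather than assuming.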
