What is the most efficient way to find the position of the first np.nan value?

南旧 2020-12-03 21:46

Consider the array a:

a = np.array([3, 3, np.nan, 3, 3, np.nan])

I could do

np.isnan(a).argmax()
4 answers
  •  不思量自难忘°
    2020-12-03 21:58

    It might also be worth looking into numba.jit. Without it, the vectorized version will likely beat a straightforward pure-Python search in most scenarios, but once the loop is compiled, the plain search takes the lead, at least in my testing:

    In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
    
    In [70]: %paste
    import numba
    
    def naive(a):
        for i in range(len(a)):
            if np.isnan(a[i]):
                return i
    
    def short(a):
        return np.isnan(a).argmax()
    
    @numba.jit
    def naive_jit(a):
        for i in range(len(a)):
            if np.isnan(a[i]):
                return i
    
    @numba.jit
    def short_jit(a):
        return np.isnan(a).argmax()
    ## -- End pasted text --
    
    In [71]: %timeit naive(a)
    100 loops, best of 3: 7.22 ms per loop
    
    In [72]: %timeit short(a)
    The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 37.7 µs per loop
    
    In [73]: %timeit naive_jit(a)
    The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
    100000 loops, best of 3: 6.79 µs per loop
    
    In [74]: %timeit short_jit(a)
    The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 144 µs per loop
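
Apart from speed, the two approaches timed above differ when the array contains no NaN at all. The following is a small sketch, not part of the original benchmark, that reuses the answer's `naive` and `short` functions; the arrays `with_nan` and `without_nan` are made up for illustration.

```python
import numpy as np

def naive(a):
    # Early-exit loop from the answer; implicitly returns None if no NaN is found.
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i

def short(a):
    # Builds the full boolean mask, then takes the index of the first True.
    return np.isnan(a).argmax()

# Illustrative inputs (assumptions, not from the original post).
with_nan = np.array([3, 3, np.nan, 3, 3, np.nan])
without_nan = np.array([3.0, 3.0, 3.0])

print(naive(with_nan), short(with_nan))        # 2 2
print(naive(without_nan), short(without_nan))  # None 0
```

So `short` cannot tell "NaN at position 0" from "no NaN anywhere" without an extra check, while `naive` signals the latter with `None`.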
    

    Edit: As pointed out by @hpaulj in their answer, numpy actually ships with an optimized short-circuited search: calling argmax directly on the float array stops at the first NaN, and its performance is comparable with the JITted search above:

    In [26]: %paste
    def plain(a):
        return a.argmax()
    
    @numba.jit
    def plain_jit(a):
        return a.argmax()
    ## -- End pasted text --
    
    In [35]: %timeit naive(a)
    100 loops, best of 3: 7.13 ms per loop
    
    In [36]: %timeit plain(a)
    The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
    100000 loops, best of 3: 7.04 µs per loop
    
    In [37]: %timeit naive_jit(a)
    100000 loops, best of 3: 6.91 µs per loop
    
    In [38]: %timeit plain_jit(a)
    10000 loops, best of 3: 125 µs per loop
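
One caveat with `a.argmax()` is that on an array without any NaN it simply returns the index of the largest value. Below is a minimal sketch of a guarded wrapper, assuming a 1-D float array; the name `first_nan_index` and the `-1` sentinel are hypothetical choices for illustration, not something from the answer or from NumPy.

```python
import numpy as np

def first_nan_index(a):
    """Return the index of the first NaN in a 1-D float array, or -1 if there is none."""
    if a.size == 0:
        # argmax raises on an empty array, so handle that case up front.
        return -1
    i = a.argmax()  # for float arrays, NaN is treated as the maximum
    # If the "maximum" is not NaN, the array contains no NaN at all.
    return int(i) if np.isnan(a[i]) else -1

a = np.array([3, 3, np.nan, 3, 3, np.nan])
print(first_nan_index(a))                      # 2
print(first_nan_index(np.array([1.0, 5.0])))   # -1
```

The extra `np.isnan(a[i])` check is a single scalar test, so it should not change the timing picture above in any meaningful way.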
    
