Comparing NumPy arrays so that NaNs compare equal

南旧 2020-12-15 18:16

Is there an idiomatic way to compare two NumPy arrays that treats NaNs as equal to each other (but not equal to anything other than a NaN)?

For e

4 answers
  •  不知归路
    2020-12-15 18:48

    Disclaimer: I don't recommend this for regular use, and I wouldn't use it myself, but I could imagine rare circumstances under which it might be useful.

    If the arrays have the same shape and dtype, you could consider using the low-level memoryview:

    >>> import numpy as np
    >>> 
    >>> a0 = np.array([1.0, np.nan, 2.0])
    >>> ac = a0 * (1+0j)
    >>> b0 = np.array([1.0, np.nan, 2.0])
    >>> b1 = np.array([1.0, np.nan, 2.0, np.nan])
    >>> c0 = np.array([1.0, 0.0, 2.0])
    >>> 
    >>> memoryview(a0)
    <memory at 0x...>
    >>> memoryview(a0) == memoryview(a0)
    True
    >>> memoryview(a0) == memoryview(ac) # equal but different dtype
    False
    >>> memoryview(a0) == memoryview(b0) # hooray!
    True
    >>> memoryview(a0) == memoryview(b1)
    False
    >>> memoryview(a0) == memoryview(c0)
    False
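
The session above looks like Python 2 output. On Python 3.3 and later, `memoryview` equality is documented to compare element values according to the format code rather than raw memory, so NaNs are not expected to compare equal that way. A minimal sketch of the same bitwise idea, assuming that comparing raw bytes is acceptable; `bitwise_equal` is a hypothetical helper name, not a NumPy function:

```python
import numpy as np

def bitwise_equal(a, b):
    """Hypothetical helper: True if a and b share shape, dtype and raw bytes."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Raw bytes carry no shape or dtype information, so check those first.
    if a.shape != b.shape or a.dtype != b.dtype:
        return False
    # tobytes() copies the data in C order; identical bit patterns give equal bytes.
    return a.tobytes() == b.tobytes()

a0 = np.array([1.0, np.nan, 2.0])
ac = a0 * (1 + 0j)
b0 = np.array([1.0, np.nan, 2.0])
b1 = np.array([1.0, np.nan, 2.0, np.nan])
c0 = np.array([1.0, 0.0, 2.0])

print(bitwise_equal(a0, b0))  # same values, same NaN bit pattern
print(bitwise_equal(a0, ac))  # different dtype
print(bitwise_equal(a0, b1))  # different shape
print(bitwise_equal(a0, c0))  # NaN versus 0.0
```

Like the memoryview trick, this distinguishes values whose bit patterns differ even when they compare equal, and unlike it, `tobytes()` makes a copy of each array.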
    

    But beware of subtle problems like this:

    >>> zp = np.array([0.0])
    >>> zm = -1*zp
    >>> zp
    array([ 0.])
    >>> zm
    array([-0.])
    >>> zp == zm
    array([ True], dtype=bool)
    >>> memoryview(zp) == memoryview(zm)
    False
    

    This happens because 0.0 and -0.0 have different binary representations even though they compare equal (they have to differ, of course: that is how NumPy knows to print the negative sign):

    >>> memoryview(zp)[0]
    '\x00\x00\x00\x00\x00\x00\x00\x00'
    >>> memoryview(zm)[0]
    '\x00\x00\x00\x00\x00\x00\x00\x80'
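
On Python 3, indexing a `memoryview` of a float64 array is expected to return a Python float rather than a byte string, so the output above may not reproduce there. A minimal sketch that exposes the same bytes with `tobytes()`, assuming float64 data; the `normalized` variable is only an illustration of one possible workaround:

```python
import numpy as np

zp = np.array([0.0])
zm = -1 * zp

# Force little-endian float64 so the byte order does not depend on the machine.
zp_bytes = zp.astype('<f8').tobytes()
zm_bytes = zm.astype('<f8').tobytes()
print(zp_bytes)
print(zm_bytes)

# np.signbit reports the sign bit directly, without inspecting bytes.
print(np.signbit(zp), np.signbit(zm))

# Adding 0.0 maps -0.0 to +0.0 under the default rounding mode,
# which removes this particular mismatch before a bitwise comparison.
normalized = zm + 0.0
print(normalized.tobytes() == zp.tobytes())
```

Normalizing the zeros this way does not help with NaNs that carry different sign bits or payloads; those would still differ at the byte level.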
    

    On the bright side, it short-circuits the way you might hope it would:

    In [47]: a0 = np.arange(10**7)*1.0
    In [48]: a0[-1] = np.nan
    In [49]: b0 = np.arange(10**7)*1.0
    In [50]: b0[-1] = np.nan
    In [51]: timeit memoryview(a0) == memoryview(b0)
    10 loops, best of 3: 31.7 ms per loop
    In [52]: c0 = np.arange(10**7)*1.0
    In [53]: c0[0] = np.nan
    In [54]: d0 = np.arange(10**7)*1.0
    In [55]: d0[0] = 0.0
    In [56]: timeit memoryview(c0) == memoryview(d0)
    100000 loops, best of 3: 2.51 us per loop
    

    and for comparison:

    In [57]: timeit np.all((a0 == b0) | (np.isnan(a0) & np.isnan(b0)))
    1 loops, best of 3: 296 ms per loop
    In [58]: timeit np.all((c0 == d0) | (np.isnan(c0) & np.isnan(d0)))
    1 loops, best of 3: 284 ms per loop
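
The elementwise expression timed above can be wrapped in a small function. A minimal sketch, assuming numeric arrays; `nan_equal` is a hypothetical name, and the shape check is an added assumption so that mismatched shapes return False instead of broadcasting. NumPy 1.19 and later also accept `equal_nan=True` in `np.array_equal`, shown alongside for comparison:

```python
import numpy as np

def nan_equal(a, b):
    """Hypothetical helper: elementwise equality where NaN matches only NaN."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Refuse to broadcast: arrays of different shape are never equal here.
    if a.shape != b.shape:
        return False
    # Positions match if the values are equal or both are NaN.
    return bool(np.all((a == b) | (np.isnan(a) & np.isnan(b))))

a0 = np.array([1.0, np.nan, 2.0])
b0 = np.array([1.0, np.nan, 2.0])
b1 = np.array([1.0, np.nan, 2.0, np.nan])
c0 = np.array([1.0, 0.0, 2.0])
zp = np.array([0.0])
zm = -1 * zp

print(nan_equal(a0, b0))   # NaNs in the same positions
print(nan_equal(a0, b1))   # different shapes
print(nan_equal(a0, c0))   # NaN versus 0.0
print(nan_equal(zp, zm))   # 0.0 and -0.0 compare equal by value

# Built-in alternative, assuming NumPy >= 1.19.
print(np.array_equal(a0, b0, equal_nan=True))
```

Unlike the bitwise comparison, this treats 0.0 and -0.0 as equal and does not care which NaN bit pattern is stored, at the cost of always scanning both arrays.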
    
