How to broadcast a function over a numpy array, when dtype=object?

Submitted by 谁都会走 on 2019-12-12 09:12:17

Question


I have an array of numerical values that had to be stored with dtype=object (object pointers instead of values), because the rows have unequal lengths:

In [145]: import numpy as np

In [147]: a = np.array([[1,2],[3,4,5]], dtype=object)  # recent numpy versions require an explicit dtype=object for ragged input

In [148]: a
Out[148]: array([[1, 2], [3, 4, 5]], dtype=object)

In [150]: np.sin(a)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-150-58d97006f018> in <module>()
----> 1 np.sin(a)

In [152]: np.sin(a[0])
Out[152]: array([ 0.84147098,  0.90929743])

How do I broadcast a function over the actual numerical values without having to manually traverse the array?


Answer 1:


There are a couple of different issues here. First, there's little to be gained by broadcasting over python objects in numpy; you'll probably do better using pure python in this case.

>>> a = np.array([[1, 2, 3], [4, 5, 6]], dtype=object)
>>> b = np.arange(1, 7).reshape(2, 3)
>>> c = [[1, 2, 3], [4, 5, 6]]
>>> %timeit a * 5
100000 loops, best of 3: 4.28 µs per loop
>>> %timeit b * 5
100000 loops, best of 3: 2.08 µs per loop
>>> %timeit [[x * 5 for x in l] for l in c]
1000000 loops, best of 3: 998 ns per loop

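These timings are for tiny inputs and will scale differently as the data grows, but the point stands: with dtype=object, numpy falls back to calling Python-level operations on each element, so it loses most of its usual speed advantage.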
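
As a minimal sketch of the "pure python" route applied to the question's data, one option is to keep the ragged rows in a plain list and call np.sin on each row. The names rows and sines are made up for illustration; this is one possible way to do it, not the only one.

```python
import numpy as np

# Ragged data kept as a plain Python list of lists (no object array).
rows = [[1, 2], [3, 4, 5]]

# np.sin accepts a list and returns a regular float array, so each
# row is still computed by numpy; only the outer loop is Python.
sines = [np.sin(r) for r in rows]

for r, s in zip(rows, sines):
    print(r, "->", s)
```

The result is a list of ordinary float arrays, one per row, each with its own length.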

Second, the problem isn't directly related to broadcasting. numpy will happily broadcast over python lists. The result just isn't what you expect:

>>> a = np.array([[1, 2, 3], [4, 5]], dtype=object)
>>> a * 5
array([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
       [4, 5, 4, 5, 4, 5, 4, 5, 4, 5]], dtype=object)

numpy allows the objects in the array to define their own versions of whichever operator or function it's broadcasting. In this case, python lists define * as repetition! This holds even for heterogeneous arrays; try this: np.array([5, [1, 2]], dtype=object) * 5. The reason sin doesn't broadcast in this case is that, for object arrays, numpy looks for a sin method on each element, and python lists don't define one.
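
To make the dispatch behaviour concrete, here is a small sketch. The Row class is hypothetical and exists only for this illustration; it wraps a row of numbers and defines a sin method, which is what np.sin calls on each element of an object array.

```python
import numpy as np

class Row:
    """Hypothetical wrapper, used only to illustrate method dispatch."""
    def __init__(self, values):
        self.values = np.asarray(values, dtype=float)

    def sin(self):
        # np.sin on an object array calls this method on each element.
        return Row(np.sin(self.values))

# Fill the object array element by element so numpy does not try to
# interpret the contents as nested sequences.
a = np.empty(2, dtype=object)
a[0] = Row([1, 2])
a[1] = Row([3, 4, 5])

result = np.sin(a)  # object array of Row instances

# The heterogeneous example from the text: the int is multiplied,
# the list is repeated.
scaled = np.array([5, [1, 2]], dtype=object) * 5
```

This works, but it still runs one Python method call per element, so it does not recover numpy's speed.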

You'd probably be better off using a fixed-width array with a mask.

>>> np.ma.array([[1, 2, 3], [4, 5, 6]], mask=[[0, 0, 0], [0, 0, 1]])
    masked_array(data =
 [[1 2 3]
 [4 5 --]],
             mask =
 [[False False False]
 [False False  True]],
       fill_value = 999999)

As you can see, you can "simulate" a ragged array this way, and it will behave just as expected.

>>> a = np.ma.array([[1, 2, 3], [4, 5, 6]], mask=[[0, 0, 0], [0, 0, 1]])
>>> np.sin(a)
    masked_array(data =
 [[0.841470984808 0.909297426826 0.14112000806]
 [-0.756802495308 -0.958924274663 --]],
             mask =
 [[False False False]
 [False False  True]],
       fill_value = 1e+20)
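
The example above writes the mask by hand. As a rough sketch of building it from ragged input instead, the helper below pads each row and masks the padding. The name ragged_to_masked and its fill parameter are assumptions made for this illustration, not part of numpy.

```python
import numpy as np

def ragged_to_masked(rows, fill=0.0):
    """Hypothetical helper: pad ragged rows and mask the padding."""
    maxlen = max(len(r) for r in rows)
    data = np.full((len(rows), maxlen), fill, dtype=float)
    mask = np.ones((len(rows), maxlen), dtype=bool)  # start fully masked
    for i, r in enumerate(rows):
        data[i, :len(r)] = r       # copy the real values
        mask[i, :len(r)] = False   # unmask only the real values
    return np.ma.array(data, mask=mask)

m = ragged_to_masked([[1, 2], [3, 4, 5]])
s = np.sin(m)               # masked slots stay masked
row_sums = s.sum(axis=1)    # reductions skip masked slots
```

Because the padding is masked rather than merely filled, the choice of fill value does not affect reductions such as sum or mean.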

It's worth mentioning a few ways to create masked arrays. In your case, masked_invalid might be useful.

>>> np.ma.masked_invalid([[1, 2, 3], [4, 5, np.nan]])
masked_array(data =
 [[1.0 2.0 3.0]
 [4.0 5.0 --]],
             mask =
 [[False False False]
 [False False  True]],
       fill_value = 1e+20)

You can also create masked arrays using conditions:

>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.ma.masked_where(x > 5, x)
masked_array(data =
 [[1 2 3]
 [4 5 --]],
             mask =
 [[False False False]
 [False False  True]],
       fill_value = 999999)

For a full list of variations on these techniques, see the numpy.ma (masked array) documentation.




Answer 2:


As others have suggested, it's best to avoid arrays with dtype=object.

Another approach for avoiding that, which surprisingly nobody has mentioned so far, is padding with NaNs, in order to achieve a common shape.

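a = np.array([[1,2],[3,4,5]], dtype=object)  # explicit dtype=object is required for ragged input in recent numpy
maxlen = max(len(x) for x in a)
b = np.array([ x+[np.nan]*(maxlen-len(x)) for x in a ])  # np.nan: the np.NaN alias was removed in numpy 2.0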
b
=> array([[  1.,   2.,  nan], [  3.,   4.,   5.]])
b.shape
=> (2, 3)
np.sin(b) 
=> array([[ 0.84147098,  0.90929743,         nan],
          [ 0.14112001, -0.7568025 , -0.95892427]])

Of course, handling arrays containing NaNs should be done with care, e.g. you probably want to use nanmax instead of max, etc.
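
A short sketch of what that care looks like in practice, using the same padded layout as above. The variable names are chosen for illustration only.

```python
import numpy as np

rows = [[1, 2], [3, 4, 5]]
maxlen = max(len(r) for r in rows)

# Pad each row with NaN up to the longest row's length.
b = np.array([r + [np.nan] * (maxlen - len(r)) for r in rows])

plain = np.max(b, axis=1)     # NaN propagates into the padded row
safe = np.nanmax(b, axis=1)   # NaN-aware: ignores the padding

# The original row lengths can be recovered by counting non-NaN entries.
lengths = np.sum(~np.isnan(b), axis=1)
```

Note that this padding scheme cannot distinguish padding from a genuine NaN in the data; if that matters, a masked array keeps the two separate.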



Source: https://stackoverflow.com/questions/23795569/how-to-broadcast-a-function-over-a-numpy-array-when-dtype-object
