Question
I need to sum up the elements of a 1D NumPy array (below: data) based on another array that holds each element's class membership (groups), where labels lists the possible classes. I use Numba in the code below to speed this up. However, if I do not explicitly cast with int() in the line ret[int(find(labels, g))] += y, I receive this error message:

TypeError: unsupported array index type ?int64

Is there a better workaround than explicit casting?
import numpy as np
from numba import jit

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)


@jit(nopython=True)
def find(seq, value):
    for ct, x in enumerate(seq):
        if x == value:
            return ct
    # falls through (implicitly returns None) when value is not in seq


@jit(nopython=True)
def subsumNumba(data, groups, labels):
    ret = np.zeros(len(labels))
    for y, g in zip(data, groups):
        # not working without casting with int()
        ret[int(find(labels, g))] += y
    return ret
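For comparison, the grouped sum that subsumNumba computes can also be written without an explicit loop in plain NumPy. This is a swapped-in technique, not something proposed in the question or the answer, and subsum_numpy is a hypothetical name. The sketch keeps find's first-match behaviour (the example labels contain 45 twice, so only the first 45 ever accumulates) and, unlike the original, silently skips group values that do not appear in labels.

```python
import numpy as np


def subsum_numpy(data, groups, labels):
    """Sum data per label, matching each group value to its first occurrence in labels."""
    # matches[i, j] is True when groups[i] == labels[j]
    matches = groups[:, None] == labels[None, :]
    # rows whose group value appears somewhere in labels
    found = matches.any(axis=1)
    # argmax on a boolean row returns the index of the first True column
    idx = matches.argmax(axis=1)
    # weighted bincount adds data[i] into slot idx[i]; minlength pads unused labels with 0
    return np.bincount(idx[found], weights=data[found], minlength=len(labels))
```

The boolean matrix has len(groups) × len(labels) entries, so this suits a small number of labels like the eight in the question.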
Answer 1:
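The problem is that find can return either an int or None (when it does not find the value), so Numba infers an optional integer as its return type; the ?int64 in the message is Numba's notation for an optional int64 (an int64 or None), and I think that is what causes the error. To avoid casting, make find return an int in every case, for example a sentinel such as -1 when it exits without finding the desired value, and then handle that case in the caller.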
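A minimal sketch of that fix, assuming -1 as the "not found" sentinel and assuming that group values missing from labels should simply be skipped (raising an error instead would be an equally reasonable choice). The name subsum_numba and the try/except fallback decorator are illustrative additions, not part of the original answer; the fallback only lets the sketch run as plain Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: without Numba, run the functions as plain Python.
    def njit(func):
        return func


@njit
def find(seq, value):
    """Return the index of the first element equal to value, or -1 if absent."""
    for ct, x in enumerate(seq):
        if x == value:
            return ct
    return -1  # int sentinel: the return type is always an integer, never None


@njit
def subsum_numba(data, groups, labels):
    ret = np.zeros(len(labels))
    for y, g in zip(data, groups):
        idx = find(labels, g)
        if idx >= 0:  # the caller handles the "not found" case explicitly
            ret[idx] += y
    return ret
```

Because both return paths of find now produce an integer, Numba can type the result as a plain int64, and ret[idx] no longer needs the int() cast.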
Source: https://stackoverflow.com/questions/39316939/typeerror-when-indexing-numpy-array-using-numba