I'm working on something that requires fast coroutines and I believe numba could speed up my code.
Here's a silly example: a function that squares its input and adds the number of times it's been called.

    def make_square_plus_count():
        i = 0
        def square_plus_count(x):
            nonlocal i
            i += 1
            return x**2 + i
        return square_plus_count

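For reference, here is a quick plain-Python check of what the closure is expected to return (no Numba involved). The names f, g, results and first are only illustrative; each call to make_square_plus_count() gets its own independent counter.

```python
def make_square_plus_count():
    i = 0  # call counter captured by the closure

    def square_plus_count(x):
        nonlocal i  # rebind the counter in the enclosing scope
        i += 1
        return x**2 + i

    return square_plus_count


f = make_square_plus_count()
results = [f(x) for x in range(10)]
print(results)  # [1, 3, 7, 13, 21, 31, 43, 57, 73, 91]

# A second closure starts counting from scratch.
g = make_square_plus_count()
first = g(0)
print(first)  # 1
```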
You can't even JIT this with nopython=False, presumably due to the nonlocal keyword.
But you don't need nonlocal if you use a class instead:

    import numba

    def make_square_plus_count():
        # Note: since Numba 0.49, jitclass lives in numba.experimental.
        @numba.jitclass({'i': numba.uint64})
        class State:
            def __init__(self):
                self.i = 0

        state = State()

        @numba.jit()
        def square_plus_count(x):
            state.i += 1
            return x**2 + state.i
        return square_plus_count

This at least works, but it breaks if you do nopython=True.
Is there a solution for this that will compile with nopython=True?
If you're going to use a state class anyway, you could also use methods instead of a closure (the methods should be compiled in nopython mode):

    import numba

    @numba.jitclass({'i': numba.uint64})
    class State(object):
        def __init__(self):
            self.i = 0

        def square_plus_count(self, x):
            self.i += 1
            return x**2 + self.i

    square_with_call_count = State().square_plus_count  # using the method
    print([square_with_call_count(i) for i in range(10)])
    # [1, 3, 7, 13, 21, 31, 43, 57, 73, 91]

However, timings show that this is actually slower than a pure Python closure implementation. I expect that as long as you don't use nonlocal NumPy arrays or do operations on arrays in your method (or closure), this will be less efficient!
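A minimal sketch of how such a timing comparison could be set up, using the standard-library timeit module. The helper time_calls and the candidates dict are hypothetical names introduced here, and the Numba branch assumes Numba 0.49 or newer (where jitclass is importable from numba.experimental). If Numba is not installed, only the pure Python closure is timed. No numbers are quoted because they depend on the machine and the Numba version.

```python
import timeit


def make_square_plus_count():
    i = 0

    def square_plus_count(x):
        nonlocal i
        i += 1
        return x**2 + i

    return square_plus_count


def time_calls(func, n_calls=100_000):
    """Return the average wall-clock seconds per call of func(3)."""
    total = timeit.timeit(lambda: func(3), number=n_calls)
    return total / n_calls


# Hypothetical registry of implementations to compare.
candidates = {"pure Python closure": make_square_plus_count()}

try:
    # Assumption: Numba >= 0.49, where jitclass moved to numba.experimental.
    from numba import uint64
    from numba.experimental import jitclass

    @jitclass([("i", uint64)])
    class State:
        def __init__(self):
            self.i = 0

        def square_plus_count(self, x):
            self.i += 1
            return x**2 + self.i

    state = State()
    state.square_plus_count(3)  # warm-up call so compilation is not timed
    candidates["jitclass method"] = state.square_plus_count
except ImportError:
    pass  # Numba not available: time only the pure Python version

for name, func in candidates.items():
    print(f"{name}: {time_calls(func) * 1e9:.0f} ns per call")
```

Each call here does very little work, so the cost of crossing from the interpreter into compiled code on every call can easily dominate, which would be consistent with the observation above.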
Source: https://stackoverflow.com/questions/41842656/coroutines-in-numba