Question
I need to do a numerical integration in 6D in Python. Because scipy.integrate.nquad is slow, I am trying to speed things up by defining the integrand as a scipy.LowLevelCallable with Numba.
I was able to do this in 1D with scipy.integrate.quad by replicating the example given here:
import numpy as np
from numba import cfunc
from scipy import integrate
def integrand(t):
    return np.exp(-t) / t**2
nb_integrand = cfunc("float64(float64)")(integrand)
# regular integration
%timeit integrate.quad(integrand, 1, np.inf)
10000 loops, best of 3: 128 µs per loop
# integration with compiled function
%timeit integrate.quad(nb_integrand.ctypes, 1, np.inf)
100000 loops, best of 3: 7.08 µs per loop
When I try to do the same with nquad, the nquad documentation says:
If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:
double func(int n, double *xx)
double func(int n, double *xx, void *user_data)
where n is the number of extra parameters and args is an array of doubles of the additional parameters, the xx array contains the coordinates. The user_data is the data contained in the scipy.LowLevelCallable.
But the following code gives me an error:
from numba import cfunc
import ctypes
def func(n_arg,x):
    xe = x[0]
    xh = x[1]
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh)
nb_func = cfunc("float64(int64,CPointer(float64))")(func)
integrate.nquad(nb_func.ctypes, [[0,1],[0,1]], full_output=True)
error: quad: first argument is a ctypes function pointer with incorrect signature
Is it possible to compile a function with numba that can be used with nquad directly in the code and without defining the function in an external file?
Thank you very much in advance!
Answer 1:
Wrapping the function in a scipy.LowLevelCallable makes nquad happy:
import scipy as sp
import scipy.integrate as si
si.nquad(sp.LowLevelCallable(nb_func.ctypes), [[0,1],[0,1]], full_output=True)
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323})
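To make the C signature quoted from the nquad documentation more concrete, here is a minimal sketch that uses only the standard-library ctypes module. It builds the prototype double func(int n, double *xx) by hand and calls a plain Python callback through it. The names PROTOTYPE, py_integrand and coords are made up for illustration and are not part of SciPy or Numba. It also prints the size of a C int next to a 64-bit integer; on common platforms these differ, in which case declaring the first parameter as int64 would not describe the same C signature.

```python
import ctypes
import math

# Hand-built prototype for: double func(int n, double *xx)
PROTOTYPE = ctypes.CFUNCTYPE(
    ctypes.c_double,                  # return type
    ctypes.c_int,                     # n: number of values in xx
    ctypes.POINTER(ctypes.c_double),  # xx: the coordinates
)


def py_integrand(n, xx):
    # xx is a pointer to doubles; index it like an array.
    return math.sin(2 * math.pi * xx[0]) * math.sin(2 * math.pi * xx[1])


c_integrand = PROTOTYPE(py_integrand)

# Call it the way a C integrator would: a count plus a pointer to the point.
coords = (ctypes.c_double * 2)(0.25, 0.25)
value = c_integrand(2, coords)
print(value)

# Width of a C int versus a 64-bit integer on this machine, in bytes.
int_size = ctypes.sizeof(ctypes.c_int)
int64_size = ctypes.sizeof(ctypes.c_int64)
print(int_size, int64_size)
```

This callback is still interpreted Python, so it gains no speed. It only illustrates the calling convention that a compiled Numba cfunc would have to match.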
Answer 2:
The signature of the function you pass to nquad should be double func(int n, double *xx). You can create a decorator for your function func like so:
import numpy as np
import scipy.integrate as si
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable
def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def func(xe, xh):
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh)
print(si.nquad(func, [[0,1],[0,1]], full_output=True))
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323})
Source: https://stackoverflow.com/questions/45823212/how-can-you-implement-a-c-callable-from-numba-for-efficient-integration-with-nqu