How to reduce integration time for integration over 2D connected domains

寵の児 submitted on 2020-05-07 03:53:10

Question


I need to compute many 2D integrals over domains that are simply connected (and convex most of the time). I'm using the Python function scipy.integrate.nquad for this. However, these integrations take much longer than the same integration over a rectangular domain. Is a faster implementation possible?

Here is an example. I integrate a constant function first over a circular domain (imposed by a condition inside the integrand) and then over a rectangular domain (the box given directly by the limits passed to nquad).

from scipy import integrate
import time

def circular(x,y,a):
  if x**2 + y**2 < a**2/4:
    return 1 
  else:
    return 0

def rectangular(x,y,a):
  return 1

a = 4
start = time.time()
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
now = time.time()
print(now-start)

start = time.time()
result = integrate.nquad(rectangular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
now = time.time()
print(now-start)

The rectangular domain takes only 0.00029 seconds, while the circular domain requires 2.07061 seconds to complete.

The circular integration also raises the following warning:

IntegrationWarning: The maximum number of subdivisions (50) has been achieved.
If increasing the limit yields no improvement it is advised to analyze 
the integrand in order to determine the difficulties.  If the position of a 
local difficulty can be determined (singularity, discontinuity) one will 
probably gain from splitting up the interval and calling the integrator 
on the subranges.  Perhaps a special-purpose integrator should be used.
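The warning points at the cause. The indicator integrand is discontinuous along the circle, so the adaptive QUADPACK routine behind quad keeps subdividing near the boundary. Below is a minimal sketch of the splitting the warning suggests. It assumes the domain can be written as x_lo(y) <= x <= x_hi(y), which holds for a disk and for any convex domain. nquad accepts callables in `ranges`, and each callable is called with the outer integration variables followed by `args`. The helper names `x_range` and `y_range` are illustrative and are not part of SciPy.

```python
import math
from scipy import integrate

def integrand(x, y, a):
    # Smooth integrand: the circular domain lives in the limits, not in an if/else
    return 1.0

def x_range(y, a):
    # Inner limits for x depend on the outer variable y and on the parameter a
    half_width = math.sqrt(max(a**2 / 4 - y**2, 0.0))  # clamp guards against tiny negatives
    return [-half_width, half_width]

def y_range(a):
    # Outer limits only receive the extra args
    return [-a / 2, a / 2]

a = 4
area, abserr = integrate.nquad(integrand, [x_range, y_range], args=(a,))
print(area, math.pi * (a / 2) ** 2)  # numerical area vs. exact area of the disk
```

With the boundary moved into the limits, every slice integrates a constant. The remaining one-dimensional integral over y is smooth inside its interval, so this should avoid the subdivision warning.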

Answer 1:


One way to make the calculation faster is to use numba, a just-in-time compiler for Python.

The @jit decorator

Numba provides a @jit decorator that compiles Python code to optimized machine code. With parallel=True, that code can also be parallelized across several CPU cores. Jitting the integrand function takes little effort and saves some time because the compiled code runs faster. You don't even have to worry about types, because Numba infers them under the hood.

from scipy import integrate
from numba import jit

@jit
def circular_jit(x, y, a):
    if x**2 + y**2 < a**2 / 4:
        return 1 
    else:
        return 0

a = 4
result = integrate.nquad(circular_jit, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

This does run faster. Timing it on my machine, I get:

 Original circular function: 1.599048376083374
 Jitted circular function: 0.8280022144317627

That is a ~50% reduction in computation time.
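Numba compiles a @jit function lazily on its first call. As a result, timing a single first call of nquad with a jitted integrand also measures compilation. The sketch below gives a fairer measurement. It uses a hypothetical helper, `best_time`, that makes one untimed warm-up call and then reports the best of several runs timed with time.perf_counter.

```python
import time
from scipy import integrate

def best_time(func, *fargs, repeat=5, **fkwargs):
    """Hypothetical helper: fastest wall-clock time of `repeat` calls after one warm-up call."""
    func(*fargs, **fkwargs)  # warm-up, not timed (absorbs one-off costs such as JIT compilation)
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        func(*fargs, **fkwargs)
        timings.append(time.perf_counter() - start)
    return min(timings)

def rectangular(x, y, a):
    return 1

a = 4
t = best_time(integrate.nquad, rectangular, [[-a/2, a/2], [-a/2, a/2]], args=(a,))
print(f"rectangular: {t:.6f} s")
```

The same call works with circular or circular_jit in place of rectangular. The minimum of several runs is less sensitive to background load than a single measurement.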

Scipy's LowLevelCallable

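Function calls in Python are comparatively expensive. nquad's compiled integration routine calls the Python integrand once per evaluation point, and every one of those calls goes through the interpreter. This overhead can make Python code slow compared with compiled languages like C.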

In order to mitigate this, Scipy provides a LowLevelCallable class which can be used to provide access to a low-level compiled callback function. Through this mechanism, Python's function call overhead is bypassed and further time saving can be made.

Note that in the case of nquad, the signature of the cfunc passed to LowLevelCallable must be one of:

double func(int n, double *xx)
double func(int n, double *xx, void *user_data)

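Here n is the total number of values passed (the integration variables plus any extra arguments), and xx points to those values: the integration variables first, followed by the extra arguments. user_data is used for callbacks that need additional context to operate.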
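This calling convention can be observed without Numba. The minimal sketch below assumes only ctypes. It wraps a Python function as a C callback with ctypes.CFUNCTYPE and passes it to LowLevelCallable with the signature spelled out explicitly. It is purely illustrative: the callback still runs Python code on every call, so it is not faster than a plain Python integrand. The names `c_integrand_t`, `_probe`, `probe_callback` and `seen_n` are hypothetical.

```python
import ctypes
from scipy import integrate, LowLevelCallable

# ctypes prototype matching `double func(int n, double *xx)`
c_integrand_t = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_int,
                                 ctypes.POINTER(ctypes.c_double))

seen_n = set()  # records every value of n the integrator passes in

def _probe(n, xx):
    seen_n.add(n)
    x, y, a = xx[0], xx[1], xx[2]  # integration variables first, then the extra args
    return a  # constant integrand equal to the parameter a

# Keep a reference so the callback is not garbage-collected while in use
probe_callback = c_integrand_t(_probe)
probe_llc = LowLevelCallable(probe_callback, signature="double (int, double *)")

a = 4.0
value, abserr = integrate.nquad(probe_llc, [[-a/2, a/2], [-a/2, a/2]], args=(a,))
print(value, seen_n)  # the constant a integrated over a 4 x 4 square, and the n values seen
```

For this 2D integral with one extra argument, n is 3 and xx holds [x, y, a]. That is the layout the Numba cfunc unpacks.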

We can therefore slightly change the circular function signature in Python to make it compatible.

from scipy import integrate, LowLevelCallable
from numba import cfunc
from numba.types import intc, CPointer, float64


@cfunc(float64(intc, CPointer(float64)))
def circular_cfunc(n, args):
    x, y, a = (args[0], args[1], args[2]) # Cannot do `(args[i] for i in range(n))` as `yield` is not supported
    if x**2 + y**2 < a**2/4:
        return 1 
    else:
        return 0

circular_LLC = LowLevelCallable(circular_cfunc.ctypes)

a = 4
result = integrate.nquad(circular_LLC, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

With this method I get

LowLevelCallable circular function: 0.07962369918823242

This is a 95% reduction compared to the original and 90% when compared to the jitted version of the function.
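The percentages follow directly from the timings above. The last term is the jitted version's reduction relative to the original, which is the ~50% quoted earlier:

$$1 - \frac{0.0796}{1.5990} \approx 0.950, \qquad 1 - \frac{0.0796}{0.8280} \approx 0.904, \qquad 1 - \frac{0.8280}{1.5990} \approx 0.482$$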

A bespoke decorator

To make the code tidier and keep the integrand function's signature flexible, we can create a bespoke decorator. It jits the integrand function and wraps it in a LowLevelCallable object that can then be passed to nquad.

from scipy import integrate, LowLevelCallable
from numba import cfunc, jit
from numba.types import intc, CPointer, float64

def jit_integrand_function(integrand_function):
    jitted_function = jit(integrand_function, nopython=True)

    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1], xx[2])
    return LowLevelCallable(wrapped.ctypes)


@jit_integrand_function
def circular(x, y, a):
    if x**2 + y**2 < a**2 / 4:
        return 1
    else:
        return 0

a = 4
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

Arbitrary number of arguments

If the number of arguments is unknown, we can use Numba's convenient carray function to convert the CPointer(float64) into a NumPy array.

import numpy as np
from scipy import integrate, LowLevelCallable
from numba import cfunc, carray, jit
from numba.types import intc, CPointer, float64

def jit_integrand_function(integrand_function):
    jitted_function = jit(integrand_function, nopython=True)

    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        ar = carray(xx, n)
        return jitted_function(ar[0], ar[1], ar[2:])
    return LowLevelCallable(wrapped.ctypes)


@jit_integrand_function
def circular(x, y, a):
    if x**2 + y**2 < a[-1]**2 / 4:
        return 1
    else:
        return 0

ar = np.array([1, 2, 3, 4])
a = ar[-1]
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=ar)
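Outside Numba's nopython code, NumPy offers a similar conversion, numpy.ctypeslib.as_array. It turns a ctypes pointer plus a shape into an array view, much as carray returns a view rather than a copy. The minimal sketch below assumes a hand-built ctypes buffer standing in for the xx pointer that nquad passes to the callback; the values are made up for illustration. It shows the same slicing as wrapped: the first two entries are x and y, and the rest are the extra arguments.

```python
import ctypes
import numpy as np

# Stand-in for what the integrator hands over: n doubles behind a double*
values = (ctypes.c_double * 6)(0.5, -0.25, 1.0, 2.0, 3.0, 4.0)  # x, y, then four extra args
n = len(values)
xx = ctypes.cast(values, ctypes.POINTER(ctypes.c_double))

# NumPy analogue of carray(xx, n): an array view over the same memory (no copy)
ar = np.ctypeslib.as_array(xx, shape=(n,))
x, y, extra = ar[0], ar[1], ar[2:]
print(x, y, extra, extra[-1])  # extra[-1] plays the role of a[-1] in circular()

values[5] = 8.0  # writing through the original buffer...
print(ar[-1])    # ...is visible in the array, confirming it is a view
```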


Source: https://stackoverflow.com/questions/60600672/how-to-reduce-integration-time-for-integration-over-2d-connected-domains
