Need to better understand how rtol and atol work in scipy.integrate.odeint

半腔热情 submitted on 2019-12-06 06:37:54

Question


Here scipy.integrate.odeint is called on six different standard ODE problems with rtol = atol ranging from 1E-06 to 1E-13. As a rough representation of "error", I take the maximum difference between the results at each larger tolerance and those at the smallest tolerance. I'm curious why, for a given tolerance, one problem (D5) gives errors a million times larger than another (C1), even though the number of steps varies over a fairly narrow range (within a factor of 10).

The citation for the ODE problems is given in the script. All problems are fairly well normalized, so I set rtol equal to atol.
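For context on what rtol and atol actually bound: the scipy.integrate.odeint docstring says the solver controls the vector e of estimated local errors so that the max-norm of e / ewt is at most 1, where ewt = rtol * abs(y) + atol. The sketch below only evaluates that criterion for an illustrative state (roughly D5 at perihelion); `error_weights` and `passes_local_error_test` are hypothetical helpers, not part of SciPy, and nothing here calls the solver.

```python
import numpy as np

def error_weights(y, rtol, atol):
    """Per-component error weights ewt = rtol*|y| + atol, as described
    in the scipy.integrate.odeint docstring."""
    return rtol * np.abs(np.asarray(y, dtype=float)) + atol

def passes_local_error_test(e, y, rtol, atol):
    """Hypothetical helper: True if max-norm of (e / ewt) <= 1.
    It mimics the documented criterion; it is not the solver's own code."""
    ewt = error_weights(y, rtol, atol)
    return np.max(np.abs(np.asarray(e, dtype=float)) / ewt) <= 1.0

tol = 1e-8
# Illustrative state like D5 at perihelion: x = 0.1, y = 0, vx = 0, vy = sqrt(19)
y_peri = np.array([0.1, 0.0, 0.0, np.sqrt(19.0)])
print(error_weights(y_peri, tol, tol))
print(passes_local_error_test([5e-9] * 4, y_peri, tol, tol))
print(passes_local_error_test([2e-8, 0.0, 0.0, 0.0], y_peri, tol, tol))
```

Note that this is a per-step test against the current state; it says nothing about how an accepted error grows during later steps.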

To reiterate: my question is why the errors vary by a factor of almost 1E+06 between problems, even though within each problem they scale with the tolerance. Of course C1 is the "softest" problem and D5 has dramatic peaks at "perihelion", but I expected the routine to adjust its step sizes internally so that the errors would end up similar.
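One way to confirm that the step sizes really are being adjusted is the info dictionary returned with `full_output=True`: its `'hu'` entry holds the step size last used in each output interval and `'nst'` the cumulative step count. A sketch, assuming the same orbit equations and perihelion initial conditions as D1 (e = 0.1) and D5 (e = 0.9) in the script; `kepler` and `kepler_y0` are names made up for this example.

```python
import numpy as np
from scipy.integrate import odeint

def kepler(y, t):
    # Same dimensionless orbit equations as deriv_D1toD5 in the question
    r3 = (y[0]**2 + y[1]**2) ** 1.5
    return [y[2], y[3], -y[0] / r3, -y[1] / r3]

def kepler_y0(ecc):
    # Start at perihelion with semi-major axis 1, as y0_D1 / y0_D5 do
    return [1.0 - ecc, 0.0, 0.0, np.sqrt((1.0 + ecc) / (1.0 - ecc))]

t = np.linspace(0.0, 20.0, 2001)
step_sizes = {}
for ecc in (0.1, 0.9):
    sol, info = odeint(kepler, kepler_y0(ecc), t, rtol=1e-8, atol=1e-8,
                       full_output=True)
    step_sizes[ecc] = info["hu"]  # step size last used in each output interval
    print(f"e={ecc}: steps={info['nst'][-1]}, "
          f"min step={info['hu'].min():.2e}, max step={info['hu'].max():.2e}")
```

If this behaves as I'd expect, the e = 0.9 run takes far smaller steps near perihelion, so the routine does adapt; what it adapts to is the local error estimate, not the final error.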

EDIT: I've added plots of the time evolution of the "errors", which may shed some light.

import numpy as np
from scipy.integrate import odeint as ODEint
import matplotlib.pyplot as plt
import timeit

# FROM: "Comparing Numerical Methods for Ordinary Differential Equations"
# T.E. Hull, W.H. Enright, B.M. Fellen and A.E. Sedgwick
# SIAM J. Numer. Anal. vol 9, no 4, December 1972, pp: 603-637

def deriv_B1(y, x):
    return [2.*(y[0]-y[0]*y[1]), -(y[1]-y[0]*y[1])] # "growth of two conflicting populations"

def deriv_B4(y, x):
    A = 1./np.sqrt(y[0]**2 + y[1]**2)
    return [-y[1] - A*y[0]*y[2],  y[0] - A*y[1]*y[2],  A*y[0]]  # "integral surface of a torus"

def deriv_C1(y, x):
    return [-y[0]] + [y[i]-y[i+1] for i in range(8)] + [y[8]] # a radioactive decay chain

def deriv_D1toD5(y, x):
    A = -(y[0]**2 + y[1]**2)**-1.5
    return [y[2],  y[3],  A*y[0],  A*y[1]] # dimensionless orbit equation

deriv_D1, deriv_D5 = deriv_D1toD5, deriv_D1toD5   # D1 and D5 differ only in eccentricity (initial conditions)

def deriv_E1(y, x):
    return [y[1], -(y[1]/(x+1.0) + (1.0 - 0.25/(x+1.0)**2)*y[0])] # derived from Bessel's equation of order 1/2

def deriv_E3(y, x):
    return [y[1], y[0]**3/6.0 - y[0] + 2.0*np.sin(2.78535*x)] # derived from Duffing's equation

y0_B1 = [1.0, 3.0]
y0_B4 = [3.0, 0.0, 0.0]
y0_C1 = [1.0] + [0.0 for i in range(9)]
ep1, ep5 = 0.1, 0.9
y0_D1 = [1.0-ep1, 0.0, 0.0, np.sqrt((1.0+ep1)/(1.0-ep1))]
y0_D5 = [1.0-ep5, 0.0, 0.0, np.sqrt((1.0+ep5)/(1.0-ep5))]
y0_E1 = [0.6713968071418030, 0.09540051444747446] # J(1/2, 1), Jprime(1/2, 1)
y0_E3 = [0.0, 0.0]

x  = np.linspace(0, 20, 51)
xa = np.linspace(0, 20, 2001)

derivs = [deriv_B1, deriv_B4, deriv_C1, deriv_D1, deriv_D5, deriv_E3]
names  = ["deriv_B1", "deriv_B4", "deriv_C1", "deriv_D1", "deriv_D5", "deriv_E3"]
y0s    = [y0_B1, y0_B4, y0_C1, y0_D1, y0_D5, y0_E3]

timeit_dict, answer_dict, info_dict = dict(), dict(), dict()

ntimes = 10
tols   = [10.**-i for i in range(6, 14)]

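def F(deriv, y0, tol):      # low density of time points, no output, for the speed test
    ODEint(deriv, y0, x, rtol=tol, atol=tol)

def Fa(deriv, y0, tol):     # high density of time points, full output, for plotting
    return ODEint(deriv, y0, xa, rtol=tol, atol=tol, full_output=True)

for deriv, y0, name in zip(derivs, y0s, names):
    # Pass deriv, y0 and tol explicitly: in Python 3 the comprehension variable
    # `tol` is not visible to F/Fa as a global, so relying on globals fails.
    timez = [timeit.timeit(lambda: F(deriv, y0, tol), number=ntimes)/float(ntimes) for tol in tols]
    timeit_dict[name] = timez
    alist, dlist = zip(*[Fa(deriv, y0, tol) for tol in tols])
    answer_dict[name] = np.array([a.T for a in alist])   # shape (n_tols, n_vars, len(xa))
    info_dict[name] = dlist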

plt.figure(figsize=[10,6])

for i, name in enumerate(names):
    plt.subplot(2, 3, i+1)
    for thing in answer_dict[name][-1]:
        plt.plot(xa, thing)
    plt.title(name[-2:], fontsize=16)
plt.show()

plt.figure(figsize=[10, 8])
for i, name in enumerate(names):
    plt.subplot(2,3,i+1)
    a = answer_dict[name]
    a13, a10, a8 = a[-1], a[-4], a[-6]
    d10 = np.abs(a10-a13).max(axis=0)
    d8  = np.abs(a8 -a13).max(axis=0)
    plt.plot(xa, d10, label="tol(1E-10)-tol(1E-13)")
    plt.plot(xa, d8,  label="tol(1E-08)-tol(1E-13)")
    plt.yscale('log')
    plt.ylim(1E-11, 1E-03)
    plt.title(name[-2:], fontsize=16)
    if i==3:
        plt.text(3, 1E-10, "1E-10 - 1E-13", fontsize=14)
        plt.text(2, 2E-05, "1E-08 - 1E-13", fontsize=14)
plt.show()

fs = 16
plt.figure(figsize=[12,6])

plt.subplot(1,3,1)
for name in names:
    plt.plot(tols, timeit_dict[name])
plt.title("timing results", fontsize=16)
plt.xscale('log')
plt.yscale('log')
plt.text(1E-09, 5E-02, "D5", fontsize=fs)
plt.text(1E-09, 4.5E-03, "C1", fontsize=fs)

plt.subplot(1,3,2)
for name in names:
    a = answer_dict[name]
    e = a[:-1] - a[-1]
    em = [np.abs(thing).max() for thing in e]
    plt.plot(tols[:-1], em)
plt.title("max difference from smallest tol", fontsize=16)
plt.xscale('log')
plt.yscale('log')
plt.xlim(min(tols), max(tols))
plt.text(1E-09, 3E-03, "D5", fontsize=fs)
plt.text(1E-09, 8E-11, "C1", fontsize=fs)

plt.subplot(1,3,3)
for name in names:
    nsteps = [d['nst'][-1] for d in info_dict[name]]
    plt.plot(tols, nsteps, label=name[-2:])
plt.title("number of steps", fontsize=16)
plt.xscale('log')
plt.yscale('log')
plt.ylim(3E+01, 3E+03)
plt.legend(loc="upper right", shadow=False, fontsize="large")
plt.text(2E-12, 2.3E+03, "D5", fontsize=fs)
plt.text(2E-12, 1.5E+02, "C1", fontsize=fs)

plt.show()
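The script uses the difference from the tol = 1E-13 run as a stand-in for the error. For C1 that proxy can be checked against the true solution, because with all decay constants equal to 1 and y(0) = [1, 0, ..., 0] the chain has the closed form y_k(x) = x^k e^(-x) / k! for k = 0..8, with the last component equal to 1 minus the sum of the others (the total is conserved). A sketch of that comparison; `exact_C1` is a helper written for this example.

```python
import math
import numpy as np
from scipy.integrate import odeint

def deriv_C1(y, x):
    # Radioactive decay chain from the question (problem C1)
    return [-y[0]] + [y[i] - y[i + 1] for i in range(8)] + [y[8]]

def exact_C1(x):
    # Closed form for unit decay constants and y(0) = [1, 0, ..., 0]:
    # y_k = x**k * exp(-x) / k!  (k = 0..8); last component = 1 - sum of the others
    x = np.asarray(x, dtype=float)
    cols = [x**k * np.exp(-x) / math.factorial(k) for k in range(9)]
    cols.append(1.0 - np.sum(cols, axis=0))
    return np.stack(cols, axis=1)   # shape (len(x), 10), same layout as odeint output

xa = np.linspace(0, 20, 2001)
y0_C1 = [1.0] + [0.0] * 9
ref = exact_C1(xa)
true_err = {}
for tol in (1e-6, 1e-8, 1e-10):
    sol = odeint(deriv_C1, y0_C1, xa, rtol=tol, atol=tol)
    true_err[tol] = np.abs(sol - ref).max()
    print(f"tol={tol:.0e}: max |odeint - exact| = {true_err[tol]:.2e}")
```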

Answer 1:


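Since posting the question, I've learned more. You can't just multiply the accuracy achieved per step by the number of steps and expect that to give the overall accuracy: rtol and atol control only the local error committed in each step, not how that error propagates afterwards.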
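A toy model of why the per-step accuracy doesn't simply multiply up: for the linear test equation y' = λy, an error δ committed at time t_k is carried to the end time T multiplied by exp(λ(T - t_k)). Summing those contributions gives much less than N·δ when solutions converge (λ < 0) and much more when they diverge (λ > 0). This ignores cancellation and nonlinearity, so it is an illustration under that assumption, not what odeint computes; `propagated_error` is a hypothetical helper.

```python
import numpy as np

def propagated_error(lam, local_err, n_steps, t_end):
    """Sum of equal local errors, each amplified by exp(lam * (t_end - t_k)),
    i.e. how errors propagate for the linear test equation y' = lam * y.
    Illustrative model only: no cancellation, no nonlinearity."""
    t_k = np.linspace(0.0, t_end, n_steps, endpoint=False)
    return np.sum(local_err * np.exp(lam * (t_end - t_k)))

n, delta, T = 1000, 1e-10, 20.0
naive = n * delta   # "per-step error times number of steps"
for lam in (-1.0, 0.0, 0.5):
    ratio = propagated_error(lam, delta, n, T) / naive
    print(f"lambda={lam:+.1f}: propagated / naive = {ratio:.3g}")
```

With these numbers the ratio is about 0.05 for λ = -1, equal to the naive estimate for λ = 0, and a couple of thousand for λ = 0.5.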

If solutions diverge (nearby starting points lead to paths that move much farther apart over time), numerical errors can be amplified. Every problem behaves differently; all is as it should be.
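The divergence argument can be checked directly by perturbing the initial state and measuring how far the two solutions drift apart, relative to the size of the perturbation. A sketch, assuming a perturbation of 1E-6 in the first component and tight tolerances so that solver error stays small next to the drift; `amplification` is a hypothetical helper. For C1 the system is linear and stable, so the separation should stay at about the perturbation size. For D5, moving the starting point changes the orbital energy and hence the period, so the perturbed orbit drifts in phase and the separation grows, most strongly near perihelion where the speed is highest.

```python
import numpy as np
from scipy.integrate import odeint

def deriv_C1(y, x):
    # Radioactive decay chain (problem C1)
    return [-y[0]] + [y[i] - y[i + 1] for i in range(8)] + [y[8]]

def deriv_D(y, x):
    # Dimensionless orbit equations (problems D1-D5)
    A = -(y[0]**2 + y[1]**2) ** -1.5
    return [y[2], y[3], A * y[0], A * y[1]]

def amplification(deriv, y0, x, delta=1e-6, tol=1e-11):
    """Max separation between two runs whose initial states differ by
    delta in the first component, divided by delta."""
    y0 = np.asarray(y0, dtype=float)
    y0p = y0.copy()
    y0p[0] += delta
    a = odeint(deriv, y0, x, rtol=tol, atol=tol, mxstep=10000)
    b = odeint(deriv, y0p, x, rtol=tol, atol=tol, mxstep=10000)
    return np.abs(b - a).max() / delta

x = np.linspace(0, 20, 401)
ep5 = 0.9
y0_C1 = [1.0] + [0.0] * 9
y0_D5 = [1.0 - ep5, 0.0, 0.0, np.sqrt((1.0 + ep5) / (1.0 - ep5))]
amp = {"C1": amplification(deriv_C1, y0_C1, x),
       "D5": amplification(deriv_D, y0_D5, x)}
print(amp)
```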

Hull et al. (the source of the problems in the question) is a great place to start when learning about ODE solvers:

T.E. Hull, W.H. Enright, B.M. Fellen and A.E. Sedgwick, "Comparing Numerical Methods for Ordinary Differential Equations", SIAM J. Numer. Anal., vol. 9, no. 4, December 1972, pp. 603-637.



Source: https://stackoverflow.com/questions/33748601/need-to-understand-better-how-rtol-atol-work-in-scipy-integrate-odeint
