How to use if statement in a differential equation (SciPy)?

耗尽温柔 submitted on 2019-12-04 12:05:43

Recommended solution

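This approach uses events. A terminal event stops the integration when v (u[0]) reaches 30, the state is reset, and integration restarts from that point, so each segment between discontinuities is integrated separately.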

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

a = 0.02
b = 0.2
c = -65
d = 8
i = 0

p = [a,b,c,d,i]

# Define event function and make it a terminal event
def event(t, u):
    return u[0] - 30
event.terminal = True

# Define differential equation
def fun(t, u):
    du = [(0.04*u[0] + 5)*u[0] + 150 - u[1] - p[4],
          p[0]*(p[1]*u[0]-u[1])]
    return du

u = [0,0]

ts = []
ys = []
t = 0
tend = 100
while True:
    sol = solve_ivp(fun, (t, tend), u, events=event)
    ts.append(sol.t)
    ys.append(sol.y)
    if sol.status == 1: # Event was hit
        # New start time for integration
        t = sol.t[-1]
        # Reset initial state
        u = sol.y[:, -1].copy()
        u[0] = p[2] # reset to -65
        u[1] = u[1] + p[3]
    else:
        break

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# We have to stitch together the separate simulation results for plotting
ax.plot(np.concatenate(ts), np.concatenate(ys, axis=1).T)
myleg = ax.legend(['v','u'])
plt.show()
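Because the terminal event fires exactly at each threshold crossing, the event times double as spike times. Below is a minimal sketch that reuses the same model and reset rule, records those times from `sol.t_events`, and computes the inter-spike intervals. The helper name `spike_times` and the `event.direction = 1` setting (trigger only on upward crossings) are my own additions for illustration, not part of the answer above.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Same parameters as above
a, b, c, d, i = 0.02, 0.2, -65, 8, 0

def fun(t, u):
    # Same right-hand side as the recommended solution
    return [(0.04*u[0] + 5)*u[0] + 150 - u[1] - i,
            a*(b*u[0] - u[1])]

def event(t, u):
    # Zero when v (u[0]) crosses the threshold of 30
    return u[0] - 30
event.terminal = True
event.direction = 1  # assumption: only upward crossings count as spikes

def spike_times(t0=0.0, tend=100.0, u0=(0.0, 0.0)):
    """Hypothetical helper: integrate with resets, return the event times."""
    spikes = []
    t, u = t0, np.array(u0, dtype=float)
    while True:
        sol = solve_ivp(fun, (t, tend), u, events=event)
        if sol.status != 1:  # reached tend (or failed) without an event
            break
        # t_events[0] holds the times found for our single event function
        te = sol.t_events[0][0]
        spikes.append(te)
        # Reset exactly as in the recommended solution
        t = te
        u = sol.y[:, -1].copy()
        u[0] = c
        u[1] += d
    return np.array(spikes)

spikes = spike_times()
isi = np.diff(spikes)  # inter-spike intervals
print("spike times:", spikes)
print("inter-spike intervals:", isi)
```

Reading the times from `t_events` avoids having to detect spikes afterwards by searching the stitched `ts`/`ys` arrays.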

Minimum change "solution"

It appears that your original approach, resetting the state inside the derivative function, also works with solve_ivp.

Warning: I think that in both Julia and solve_ivp, the correct way to handle this kind of thing is to use events. I believe the approach below relies on an implementation detail: the state vector passed to the function is the same object as the solver's internal state vector, which allows us to modify it in place. If it were a copy, this approach wouldn't work. In addition, nothing in this approach guarantees that the solver takes steps small enough to land on the point where the threshold is reached. Using events makes the result more accurate and more generalisable to other differential equations, for example ones that have lower gradients before the discontinuity.

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp
plt.close('all')

a = 0.02
b = 0.2
c = -65
d = 8
i = 0

p = [a,b,c,d,i]

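def fun(t, u):
    du = [0,0]
    if u[0] < 30: # Check whether the threshold has been reached
        du[0] = (0.04*u[0] + 5)*u[0] + 150 - u[1] - p[4]
        du[1] = p[0]*(p[1]*u[0]-u[1])
    else:
        # Reset in place: this mutates the array passed in by the solver,
        # which only has an effect if it is the solver's own state (see warning)
        u[0] = p[2] # reset to -65
        u[1] = u[1] + p[3]

    return du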

y0 = [0,0]

tspan = (0,100)
sol = solve_ivp(fun, tspan, y0)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(sol.t, sol.y[0, :], 'k', linewidth=5)
ax.plot(sol.t, sol.y[1, :], 'r', linewidth=5)
myleg = ax.legend(['v','u'], loc='upper right', prop={'size':28,'weight':'bold'}, bbox_to_anchor=(1,0.9))
plt.show()
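The point in the warning about landing on the threshold can be checked directly. With events, solve_ivp locates the crossing by root-finding on the step's dense output, so the state at the event sits at v = 30 up to root-finding tolerance, while the solver's regular steps land wherever step-size control puts them. A reset done inside fun only fires at whatever point the solver happens to evaluate with v >= 30. This is a minimal sketch that integrates only up to the first spike. It assumes a SciPy version that provides `sol.y_events` (added in 1.4, as far as I know).

```python
import numpy as np
from scipy.integrate import solve_ivp

# Same parameters as above
a, b, c, d, i = 0.02, 0.2, -65, 8, 0

def fun(t, u):
    return [(0.04*u[0] + 5)*u[0] + 150 - u[1] - i,
            a*(b*u[0] - u[1])]

def event(t, u):
    # Zero when v (u[0]) crosses the threshold of 30
    return u[0] - 30
event.terminal = True

# Integrate only until the first threshold crossing
sol = solve_ivp(fun, (0, 100), [0, 0], events=event)

t_hit = sol.t_events[0][0]     # crossing time found by root-finding
v_hit = sol.y_events[0][0][0]  # v at that time
v_last_step = sol.y[0, -2]     # last regular solver step before the crossing

print(f"event at t = {t_hit:.6f}, v = {v_hit:.9f}")
print(f"last regular step before it: v = {v_last_step:.3f}")
```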

Result: plot of v and u against t (image not available)
