Solving a system of ODEs (with a changing constant!) using scipy.integrate.odeint?

情话喂你 2020-12-09 22:23

I currently have a system of ODEs with a time-dependent constant, e.g.:

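def fun(u, t, a, b, c):
    # a, b and c are constants here
    x = u[0]
    y = u[1]
    z = u[2]
    dx_dt = a * x + y * z
    dy_dt = b * (y - z)
    dz_dt = - x * y + c * y - z
    return [dx_dt, dy_dt, dz_dt]
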
    南笙 (original poster)
    2020-12-09 22:53

    Yes, this is possible. In the case where a is constant, I assume you called scipy.integrate.odeint(fun, u0, t, args), where fun is defined as in your question, u0 = [x0, y0, z0] is the initial condition, t is the sequence of time points at which to solve the ODE, and args = (a, b, c) are the extra arguments passed to fun.
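    For reference, the constant-a case might look like this in full. The values of a, b, c, u0 and t are hypothetical, picked only so that the solution stays bounded; fun is the question's function, with u unpacked in one line:

```python
import numpy as np
from scipy.integrate import odeint

def fun(u, t, a, b, c):
    x, y, z = u
    dx_dt = a * x + y * z
    dy_dt = b * (y - z)
    dz_dt = -x * y + c * y - z
    return [dx_dt, dy_dt, dz_dt]

# Hypothetical values, chosen only for illustration.
a, b, c = -0.5, -1.0, 0.5
u0 = [1.0, 0.5, 0.2]                # initial condition [x0, y0, z0]
t = np.linspace(0.0, 5.0, 101)      # time points at which the solution is reported
args = (a, b, c)                    # extra arguments forwarded to fun

sol = odeint(fun, u0, t, args)      # shape (len(t), 3): columns are x, y, z
x, y, z = sol.T
```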

    When a depends on time, you simply have to redefine a as a function of t, for example (given a constant a0):

    def a(t):
        return a0 * t
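    If the coefficient is only known at sampled times rather than as a formula (the question does not say either way), one option is to wrap linear interpolation of the samples in the same kind of function. A minimal sketch, with made-up sample values t_data and a_data:

```python
import numpy as np

# Hypothetical measurements of the coefficient a at a few times.
t_data = np.array([0.0, 1.0, 2.0, 3.0])
a_data = np.array([1.0, 0.5, 0.0, 0.5])

def a(t):
    # Piecewise-linear interpolation between the samples; outside
    # [t_data[0], t_data[-1]] np.interp returns the nearest end value.
    return np.interp(t, t_data, a_data)
```

    The resulting a can be passed through args exactly like the formula version.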
    

    Then modify fun, which the solver calls to compute the derivatives at each step, so that it takes this change into account:

    def fun(u, t, a, b, c):
        x = u[0]
        y = u[1]
        z = u[2]
        dx_dt = a(t) * x + y * z # A change on this line: a -> a(t)
        dy_dt = b * (y - z)
        dz_dt = - x * y + c * y - z
        return [dx_dt, dy_dt, dz_dt]
    

    Finally, note that u0, t and args = (a, b, c) remain unchanged (a now simply refers to the function instead of a number), so you can call scipy.integrate.odeint(fun, u0, t, args) exactly as before.
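    Putting the pieces together, a complete run could look like this. The values of a0, b, c, u0 and t are hypothetical; the only point is that args = (a, b, c) now carries the function a, which fun calls as a(t):

```python
import numpy as np
from scipy.integrate import odeint

a0 = -0.5          # hypothetical value of the constant a0
b, c = -1.0, 0.5   # hypothetical values for b and c

def a(t):
    # Time-dependent coefficient: a(t) = a0 * t
    return a0 * t

def fun(u, t, a, b, c):
    x, y, z = u
    dx_dt = a(t) * x + y * z   # a is now evaluated at the current time t
    dy_dt = b * (y - z)
    dz_dt = -x * y + c * y - z
    return [dx_dt, dy_dt, dz_dt]

u0 = [1.0, 0.5, 0.2]
t = np.linspace(0.0, 5.0, 101)
args = (a, b, c)   # the function a itself is passed, not a number

sol = odeint(fun, u0, t, args)
```

    Keep in mind that odeint evaluates fun at its own internal time steps, not only at the points in t, and by default it may step slightly past t[-1] and interpolate back, so a(t) should be defined on the whole integration interval.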

    A word about the correctness of this approach: the accuracy of the numerical integration can be affected, and I don't know precisely how (I have no theoretical guarantees), but here is a simple example that works, even though a(t) jumps at t = tmax / 2:

    import matplotlib.pyplot as plt
    import numpy as np
    import scipy as sp
    import scipy.integrate
    
    tmax = 10.0
    
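    def a(t):
        # Decreases linearly from 1 to 0 on [0, tmax/2), then jumps to 1.0
        if t < tmax / 2.0:
            return ((tmax / 2.0) - t) / (tmax / 2.0)
        else:
            return 1.0
    
    def func(x, t, a):
        # dx/dt = -(x - a(t)): x relaxes towards the moving target a(t)
        return - (x - a(t))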
    
    x0 = 0.8
    t = np.linspace(0.0, tmax, 1000)
    args = (a,)
    y = sp.integrate.odeint(func, x0, t, args)
    
    fig = plt.figure()
    ax = fig.add_subplot(111)
    h1, = ax.plot(t, y)
    h2, = ax.plot(t, [a(s) for s in t])
    ax.legend([h1, h2], ["y", "a"])
    ax.set_xlabel("t")
    ax.grid()
    plt.show()
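    To put a number on the accuracy question for this particular example: with this a(t) the ODE is linear, so it can be solved by hand. Writing t_jump = tmax / 2, for t < t_jump the solution is x(t) = (1 + 1/t_jump) - t/t_jump + (x0 - 1 - 1/t_jump) e^(-t), and after the jump x relaxes exponentially towards 1. The sketch below (the closed form is derived here for this a(t), not taken from the answer) compares odeint against it, once as in the example and once with odeint's tcrit argument, which marks t_jump as a critical point the solver should not step over:

```python
import numpy as np
from scipy.integrate import odeint

tmax = 10.0
t_jump = tmax / 2.0   # time at which a(t) jumps back to 1

def a(t):
    # Same piecewise coefficient as in the example
    if t < t_jump:
        return (t_jump - t) / t_jump
    return 1.0

def func(x, t, a):
    return -(x - a(t))

def exact(t, x0):
    # Closed-form solution of x' = -(x - a(t)) for this particular a(t)
    t = np.asarray(t, dtype=float)
    A = 1.0 + 1.0 / t_jump        # constant term of the particular solution
    C = x0 - A                    # fixed by x(0) = x0
    left = A - t / t_jump + C * np.exp(-t)
    x_jump = 1.0 / t_jump + C * np.exp(-t_jump)   # value at t = t_jump
    right = 1.0 + (x_jump - 1.0) * np.exp(-(t - t_jump))
    return np.where(t < t_jump, left, right)

x0 = 0.8
t = np.linspace(0.0, tmax, 1000)

y_plain = odeint(func, [x0], t, args=(a,))[:, 0]
# tcrit: the solver is told not to step over t_jump, where a(t) is discontinuous
y_tcrit = odeint(func, [x0], t, args=(a,), tcrit=[t_jump])[:, 0]

err_plain = np.max(np.abs(y_plain - exact(t, x0)))
err_tcrit = np.max(np.abs(y_tcrit - exact(t, x0)))
print(f"max error without tcrit: {err_plain:.2e}")
print(f"max error with tcrit:    {err_tcrit:.2e}")
```

    The printed errors depend on the SciPy version and tolerances; if the tcrit run is more accurate in your setup, passing the times where a(t) jumps as tcrit is a cheap safeguard.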
    

    I hope this helps.
