How do I put a constraint on SciPy curve fit?

前端 未结 5 1988
旧巷少年郎
旧巷少年郎 2020-11-30 10:11

I'm trying to fit the distribution of some experimental values with a custom probability density function. Obviously, the integral of the resulting function should always be equal to 1 — how can I enforce that constraint during the fit?

5条回答
  •  Happy的楠姐
    2020-11-30 10:29

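    Following the example above, here is a more general way to add arbitrary constraints: minimize the sum of squared residuals with `scipy.optimize.minimize` and pass each constraint through its `constraints` argument (here, the integral of the fitted curve over [0, π] must equal 2):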

    from scipy.optimize import minimize
    from scipy.integrate import quad
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Sample points on [0, pi] and a noisy sine to fit. The noise is drawn
    # uniformly from [0, 0.4), so it only ever shifts the data upward.
    x = np.linspace(0, np.pi, num=100)
    y = np.sin(x) + 0.4 * np.random.rand(x.size)
    
    def func_to_fit(x, params):
        """Evaluate the cubic c0 + c1*x + c2*x**2 + c3*x**3 at ``x``.

        ``params`` holds the coefficients in ascending power order; only the
        first four entries are read. ``x`` may be a scalar or a NumPy array.
        """
        c0, c1, c2, c3 = params[0], params[1], params[2], params[3]
        return c0 + c1 * x + c2 * x ** 2 + c3 * x ** 3
    
    def constr_fun(params):
        """Equality constraint for ``minimize``: zero when the curve's area is 2.

        Numerically integrates ``func_to_fit`` with coefficients ``params``
        over [0, pi] and returns the integral minus 2. The value 2 is the
        exact integral of sin(x) on [0, pi], i.e. of the noise-free signal.
        """
        target_area = 2
        area = quad(func_to_fit, 0, np.pi, args=(params,))[0]  # [1] is the error estimate
        return area - target_area
    
    def func_to_minimise(params, x, y):
        """Least-squares objective: sum of squared residuals of the cubic fit.

        ``params`` are the cubic coefficients; ``x``/``y`` are the data passed
        through ``minimize``'s ``args``.
        """
        residuals = func_to_fit(x, params) - y
        return np.sum(residuals ** 2)
    
    # Do the parameter fitting twice so the two results can be compared.
    # With no `method` given, `minimize` uses BFGS for the unconstrained
    # problem and SLSQP once `constraints` are supplied.

    # Without constraints: plain least squares.
    res1 = minimize(func_to_minimise, x0=np.random.rand(4), args=(x, y))
    params1 = res1.x
    # With constraints: the curve's integral over [0, pi] must equal 2.
    cons = {'type': 'eq', 'fun': constr_fun}
    res2 = minimize(func_to_minimise, x0=np.random.rand(4), args=(x, y), constraints=cons)
    params2 = res2.x
    if not res2.success:
        # SLSQP can stop without satisfying the constraint; say so rather than
        # silently plotting a possibly meaningless result.
        print(f"Constrained fit did not converge: {res2.message}")

    y_fit1 = func_to_fit(x, params1)
    y_fit2 = func_to_fit(x, params2)

    plt.scatter(x, y, marker='.')
    plt.plot(x, y_fit2, color='y', label='constrained')
    # This curve comes from the unconstrained `minimize` call, not `curve_fit`.
    plt.plot(x, y_fit1, color='g', label='unconstrained')
    plt.legend(); plt.xlim(-0.1, 3.5); plt.ylim(0, 1.4)
    plt.show()
    # Report the residual of the constraint for the *constrained* fit; it
    # should be ~0 (the unconstrained params1 generally violate it).
    print(f"Constraint violation: {constr_fun(params2)}")
    

    Constraint violation: -2.9179325622408214e-10

提交回复
热议问题