Scipy curve_fit: how to plot the fitted curve beyond the data points?

Submitted by 末鹿安然 on 2019-12-31 07:18:11

Question


I have a number of data points, and I used SciPy's curve_fit to fit a curve to this data set. I would now like to plot the fit beyond the range of the data points, but I cannot figure out how to do it.

Here is a simple example based on an exponential fit:

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def exponential_fit(x, a, b, c):
    return a*np.exp(-b*x) + c

x = np.array([0, 1, 2, 3, 4, 5])
y = np.array([30, 50, 80, 160, 300, 580])
fitting_parameters, covariance = curve_fit(exponential_fit, x, y)
a, b, c = fitting_parameters

plt.plot(x, y, 'o', label='data')
plt.plot(x, exponential_fit(x, *fitting_parameters), '-', label='Fit')

plt.axis([0, 8, 0, 2000])
plt.legend()
plt.show()

This produces a plot in which the fitted curve stops at x = 5, the last data point, even though the x axis extends to 8.

How can I extend the fitted (orange) curve so that it goes up to x = 8? Please note that I do not want to create additional data points; I just want to extend the range of the fitted curve.

Many thanks in advance.


Answer 1:


You need to define a separate range of x values that extends beyond the range covered by your data points, and evaluate the fit function on it. This also lets you improve the representation by computing more x values for the fit function, so the curve looks smooth:

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def exponential_fit(x, a, b, c):
    return a*np.exp(-b*x) + c

x = np.array([0, 1, 2, 3, 4, 5])
y = np.array([30, 50, 80, 160, 300, 580])
fitting_parameters, covariance = curve_fit(exponential_fit, x, y)
a, b, c = fitting_parameters

x_min = -4                               # min/max values for the x axis
x_max = 8
x_fit = np.linspace(x_min, x_max, 100)   # 100 evenly spaced x values at which the fit function is evaluated
plt.plot(x, y, 'o', label='data')
plt.plot(x_fit, exponential_fit(x_fit, *fitting_parameters), '-', label='Fit')

plt.axis([x_min, x_max, 0, 2000])
plt.legend()
plt.show()

For flexibility, I introduced x_min and x_max, because the same values are used both to compute the x values for the fit function and to set the axis limits of the plot. numpy.linspace creates evenly spaced samples between a start and a stop value (here 100 samples from x_min to x_max, inclusive); these serve as the x values at which the fit function computes the corresponding y values.
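The same idea can be wrapped in a small helper. This is a minimal sketch assuming only NumPy and SciPy (plotting is left out so it runs headless). The helper name `extended_curve` is hypothetical, the data are synthetic (generated from known parameters rather than taken from the question), and the explicit `p0` starting guess is an assumption added so the solver starts near a growing exponential.

```python
import numpy as np
from scipy.optimize import curve_fit


def exponential_fit(x, a, b, c):
    return a * np.exp(-b * x) + c


def extended_curve(func, params, x_min, x_max, num=100):
    """Hypothetical helper: evaluate a fitted model on `num` evenly spaced
    points in [x_min, x_max]. The range may go past the fitted data, in which
    case the returned values are extrapolations of the model."""
    x_fit = np.linspace(x_min, x_max, num)
    return x_fit, func(x_fit, *params)


# Synthetic data generated from known parameters (a=30, b=-0.6, c=5),
# so the fit has an exact solution to recover.
x = np.arange(6, dtype=float)
y = exponential_fit(x, 30.0, -0.6, 5.0)

# Assumed starting guess: without p0, curve_fit starts from all ones,
# where b=1 describes a decaying curve while these data grow.
params, _ = curve_fit(exponential_fit, x, y, p0=(10.0, -0.5, 0.0))

# Evaluate the fitted curve from -4 to 8, well beyond the data range [0, 5].
x_fit, y_fit = extended_curve(exponential_fit, params, -4, 8)
```

Passing `x_fit` and `y_fit` to `plt.plot` draws the curve over the full range; keep in mind that the parts below x = 0 and above x = 5 are extrapolated from the model, not supported by data.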




Answer 2:


Your x ranges from 0 to 5. If you want the curve to go up to 8 (or up to eleven), you need to supply an array that ranges up to eleven... sorry, 8.

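# reuses exponential_fit and fitting_parameters from the question
x_new = np.linspace(0, 8)   # 50 evenly spaced points from 0 to 8 (linspace's default num=50)
plt.plot(x_new, exponential_fit(x_new, *fitting_parameters), '-', label='Fit')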
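A short sketch of what `np.linspace(0, 8)` produces and how the points past the last data value can be picked out, for example to draw the extrapolated part of the curve in a different style. It uses only NumPy, and the parameter values (30, -0.6, 5) are hypothetical stand-ins for the question's `fitting_parameters`.

```python
import numpy as np


def exponential_fit(x, a, b, c):
    return a * np.exp(-b * x) + c


# Hypothetical parameters standing in for the question's fitting_parameters.
params = (30.0, -0.6, 5.0)
x_data = np.array([0, 1, 2, 3, 4, 5])

# With no num argument, linspace returns 50 points and includes the stop value.
x_new = np.linspace(0, 8)
y_new = exponential_fit(x_new, *params)

# Boolean mask for points beyond the largest data x value (the extrapolated part).
extrapolated = x_new > x_data.max()
```

For instance, `plt.plot(x_new[~extrapolated], y_new[~extrapolated], '-')` followed by `plt.plot(x_new[extrapolated], y_new[extrapolated], '--')` would draw the fitted and extrapolated parts as solid and dashed lines, leaving a small gap around x = 5 between the two segments.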


Source: https://stackoverflow.com/questions/48506782/scipy-curve-fit-how-to-plot-the-fitted-curve-beyond-the-data-points
