Question
I'm trying to follow and reuse a piece of code (with my own data) suggested by @ThePredator (I couldn't comment on that thread because I don't yet have the required 50 reputation). The full code is as follows:
import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit # The module that contains the curve_fit routine
import matplotlib.pyplot as plt # This is the matplotlib module which we use for plotting the result
""" Below is the function that returns the final y according to the conditions """
def fitfunc(x,a1,a2):
    y1 = (x**(a1) )[x<xc]
    y2 = (x**(a1-a2) )[x>xc]
    y3 = (0)[x==xc]
    y = np.concatenate((y1,y2,y3))
    return y
x = array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
""" In the above code, we have imported 3 modules, namely Numpy, Scipy and matplotlib """
popt,pcov = curve_fit(fitfunc,x,y,p0=(10.0,1.0)) #here we provide random initial parameters a1,a2
a1 = popt[0]
a2 = popt[1]
residuals = y - fitfunc(x,a1,a2)
chi-sq = sum( (residuals**2)/fitfunc(x,a1,a2) ) # This is the chi-square for your fitted curve
""" Now if you need to plot, perform the code below """
curvey = fitfunc(x,a1,a2) # This is your y axis fit-line
plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()
I'm having some problems running this code. The errors I get are:
y3 = (0)[x==xc]
TypeError: 'int' object has no attribute '__getitem__'
and also:
xc is undefined
I don't see anything missing in the code (does xc really have to be defined?).
Could the author (@ThePredator), or anyone else familiar with this, please help me spot what I'm missing?
New version of the code:
import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
def fitfunc(x, a1, a2, xc):
    if x.all() < xc:
        y = x**a1
    elif x.all() > xc:
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y
xc = 2
x = np.array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = np.array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
popt,pcov = curve_fit(fitfunc,x,y,p0=(1.0,1.0))
a1 = popt[0]
a2 = popt[1]
residuals = y - fitfunc(x, a1, a2, xc)
chisq = sum((residuals**2)/fitfunc(x, a1, a2, xc))
curvey = [fitfunc(val, a1, a2, xc) for val in x] # y-axis fit-line
plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()
Answer 1:
There are multiple errors/typos in your code.
1) You cannot use - in variable names in Python (chi-sq should be chi_sq, for example).
2) You should either add from numpy import array or replace array with np.array. Currently the name array is not defined.
3) xc is not defined; you should set it before calling fitfunc().
4) y3 = (0)[x==xc] is not valid, because a plain int cannot be indexed. It should be (I think) y3 = np.zeros(len(x))[x==xc] or y3 = np.zeros(np.sum(x==xc))
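Applying fixes 2 to 4 makes the question's masked fitfunc run, but its output is still grouped by condition rather than following the order of x, which is the ordering problem described next. A minimal sketch, assuming xc = 2.0 (an arbitrary value, since the question never sets one) and a three-point toy input chosen only for illustration:

```python
import numpy as np  # fix 2: use np.array / np.zeros instead of a bare `array`

xc = 2.0  # fix 3: xc must exist before fitfunc is called (2.0 is an assumed value)


def fitfunc(x, a1, a2):
    """The question's masked fitfunc with fixes 2 to 4 applied."""
    y1 = (x**a1)[x < xc]
    y2 = (x**(a1 - a2))[x > xc]
    # fix 4: an int cannot be indexed, so build one zero per point equal to xc
    y3 = np.zeros(np.sum(x == xc))
    return np.concatenate((y1, y2, y3))


x = np.array([3.0, 1.0, 2.0])
y_masked = fitfunc(x, 2.0, 1.0)  # values come out as [x < xc, x > xc, x == xc]

# What the same formulas give when evaluated point by point, in the order of x
y_expected = np.array([3.0**(2.0 - 1.0), 1.0**2.0, 0.0])

print(y_masked)    # [1. 3. 0.]
print(y_expected)  # [3. 1. 0.]
```

The first two values are swapped because the point above xc (3.0) comes first in x but its value is appended after the point below xc.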
Your fitfunc() is also wrong, because concatenating y1, y2 and y3 changes the order of the output values: they come out grouped by condition instead of matching the order of x. What you want is:
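def fit_function(x, a1, a2, xc):
    # scalar version: evaluates a single value of x at a time
    if x < xc:
        y = x**a1
    elif x > xc:
        # note: this simplifies to x**a1; a continuous version would use xc**(a1 - a2) * x**a2
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y
xc = 2 #or any value you want
curvey = [fit_function(val, a1, a2, xc) for val in x]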
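The list comprehension above calls the scalar function once per point. A vectorized alternative (not part of this answer) uses np.where, which takes the whole array at once and keeps the output in the same order as x. This sketch assumes a continuous second branch, xc**(a1 - a2) * x**a2, so the two pieces meet at xc; that choice, the name broken_power_law and the test values are assumptions, not the answer's code. np.where evaluates both branches on every element, which is harmless for positive x such as the question's data.

```python
import numpy as np


def broken_power_law(x, a1, a2, xc):
    """Vectorized broken power law (hypothetical helper, continuous form assumed).

    Returns one value per element of x, in the same order as x.
    """
    x = np.asarray(x, dtype=float)
    below = x**a1                  # branch used where x < xc
    above = xc**(a1 - a2) * x**a2  # branch used where x >= xc; equals xc**a1 at x == xc
    return np.where(x < xc, below, above)


x = np.array([3.0, 1.0, 2.0])
print(broken_power_law(x, 2.0, 1.0, 2.0))  # [6. 1. 4.]
```

Because it works on whole arrays, a function of this shape can be handed to scipy.optimize.curve_fit directly, without a per-point loop.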
Answer 2:
Define your function as follows and it should work. Since x is an array (or list), the function should return y as an array (or list) too. You can then use it in curve_fit.
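import numpy as np

def fit_function(x, a1, a2, xc):
    y = []
    for xx in x:  # evaluate one point at a time, keeping the order of x
        if xx < xc:
            y.append(xx**a1)
        elif xx > xc:
            y.append(xx**(a1 - a2) * xx**a2)
        else:
            y.append(0.0)
    return np.array(y)  # one value per element of x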
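The fit_function above takes xc as a fourth argument, but curve_fit decides how many parameters to fit from the length of p0 and calls f(xdata, *params); with p0=(1.0, 1.0) it would pass only a1 and a2, and the call would fail for the missing xc. One way to keep xc fixed is to wrap the function in a lambda. The sketch below assumes xc = 2.0, replaces the second branch with the continuous form xc**(a1 - a2) * xx**a2 (the answer's xx**(a1 - a2) * xx**a2 simplifies to xx**a1, so a2 could not be determined), and fits synthetic noise-free data generated from a1 = 1.5, a2 = 0.5 purely to check the setup; none of these values come from the question.

```python
import numpy as np
from scipy.optimize import curve_fit


def fit_function(x, a1, a2, xc):
    """Loop-based model in the style of this answer (continuous second branch assumed)."""
    y = []
    for xx in x:
        if xx < xc:
            y.append(xx**a1)
        else:
            y.append(xc**(a1 - a2) * xx**a2)
    return np.array(y)


xc = 2.0  # break point held fixed (assumed value)

# Synthetic, noise-free data generated from known parameters, only to check the fit
a1_true, a2_true = 1.5, 0.5
x = np.linspace(0.5, 5.0, 20)
y = fit_function(x, a1_true, a2_true, xc)

# curve_fit calls f(xdata, *params) with as many params as p0 has (two here),
# so the lambda supplies the fixed xc as the fourth argument.
popt, pcov = curve_fit(lambda xdata, a1, a2: fit_function(xdata, a1, a2, xc),
                       x, y, p0=(1.0, 1.0))
print(popt)  # should be close to [1.5, 0.5]
```

Alternatively, xc can be appended to p0 so that curve_fit treats it as a third fitted parameter; that variant is not shown here.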
Source: https://stackoverflow.com/questions/32271090/curve-fitting-with-broken-power-law-in-python