Curve fitting with broken power law in Python

一笑奈何 submitted on 2021-02-08 12:12:13

Question


I'm trying to follow and reuse a piece of code (with my own data) suggested by @ThePredator (I couldn't comment on that thread because I don't yet have the required 50 reputation). The full code is as follows:

import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit # The module that contains the curve_fit routine
import matplotlib.pyplot as plt # This is the matplotlib module which we use for plotting the result

""" Below is the function that returns the final y according to the conditions """

def fitfunc(x,a1,a2):
    y1 = (x**(a1) )[x<xc]
    y2 = (x**(a1-a2) )[x>xc]
    y3 = (0)[x==xc]
    y = np.concatenate((y1,y2,y3))
    return y

x = array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])

""" In the above code, we have imported 3 modules, namely Numpy, Scipy and  matplotlib """

popt,pcov = curve_fit(fitfunc,x,y,p0=(10.0,1.0)) #here we provide random initial parameters a1,a2

a1 = popt[0] 
a2 = popt[1]
residuals = y - fitfunc(x,a1,a2)
chi-sq = sum( (residuals**2)/fitfunc(x,a1,a2) ) # This is the chi-square for your fitted curve

""" Now if you need to plot, perform the code below """
curvey = fitfunc(x,a1,a2) # This is your y axis fit-line

plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()

I'm having some problems running this code, and the errors I get are:

y3 = (0)[x==xc]

TypeError: 'int' object has no attribute '__getitem__'

and also:

xc is undefined

I don't see anything missing in the code (does xc really have to be defined?).

Could the author (@ThePredator) or someone else with knowledge of this please help me find what I've missed?

  • New version of the code:

    import numpy as np # This is the Numpy module
    from scipy.optimize import curve_fit 
    import matplotlib.pyplot as plt 
    
    def fitfunc(x, a1, a2, xc):
        if x.all() < xc:
          y = x**a1
        elif x.all() > xc:
          y = x**(a1 - a2) * x**a2
        else:
          y = 0
        return y
    
    xc = 2
    x = np.array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
    y = np.array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
    
    popt,pcov = curve_fit(fitfunc,x,y,p0=(1.0,1.0)) 
    
    a1 = popt[0] 
    a2 = popt[1]
    residuals = y - fitfunc(x, a1, a2, xc)
    chisq = sum((residuals**2)/fitfunc(x, a1, a2, xc)) 
    curvey = [fitfunc(val, a1, a2, xc) for val in x] #  y-axis fit-line
    
    plt.plot(x, curvey, 'red', label='The best-fit line')
    plt.scatter(x,y, c='b',label='The data points')
    plt.legend(loc='best')
    plt.show()
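
In the new version, x.all() < xc does not compare each element with xc. x.all() reduces the whole array to a single boolean (True whenever no element is zero), and comparing True with 2 is the same as 1 < 2, so every call takes the first branch. This minimal sketch, using the question's x and xc = 2, contrasts that with an element-wise comparison:

```python
import numpy as np

# The question's x values and break point.
x = np.array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
xc = 2

# x.all() collapses the array to one boolean (True: no element is zero),
# so this compares True (treated as 1) with 2.
collapsed = x.all()
print(collapsed, collapsed < xc)  # True True

# An element-wise comparison gives one boolean per x value instead.
below = x < xc
print(below.sum(), "of", x.size, "points lie below xc")  # 8 of 10 points lie below xc
```

So for this data the elif branch is never reached, and a single if on the whole array could not mix the two branches anyway.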
    

Answer 1:


There are multiple errors/typos in your code.

1) You cannot use - in variable names in Python (chi-sq should be chi_sq, for example); chi-sq = ... is parsed as an assignment to the expression chi - sq, which is a SyntaxError.

2) You should either add from numpy import array or replace array with np.array. Currently the name array is not defined.

3) xc is not defined; you need to set it before calling fitfunc().

4) y3 = (0)[x==xc] is not valid, because an int cannot be indexed; it should be (I think) y3 = np.zeros(len(x))[x==xc] or y3 = np.zeros(np.sum(x==xc)).
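
Even with fixes 3 and 4 applied, np.concatenate returns all x < xc values first, then x > xc, then x == xc, so the output only lines up with x by accident. The sketch below contrasts the question's structure with an order-preserving variant that writes each piece back through a boolean mask. The name fitfunc_masked and the test values are hypothetical; the model is the question's original one.

```python
import numpy as np

xc = 2.0  # fix 3: define the break point before fitfunc is called

def fitfunc(x, a1, a2):
    # Question's original structure with fix 4 applied; the pieces come out
    # grouped by branch, not in the order of x.
    y1 = (x**a1)[x < xc]
    y2 = (x**(a1 - a2))[x > xc]
    y3 = np.zeros(np.sum(x == xc))
    return np.concatenate((y1, y2, y3))

def fitfunc_masked(x, a1, a2):
    # Hypothetical order-preserving variant: same model, but each result is
    # stored at the index of the x value it came from.
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)  # entries with x == xc stay 0
    below, above = x < xc, x > xc
    y[below] = x[below]**a1
    y[above] = x[above]**(a1 - a2)
    return y

# Deliberately unsorted test values that include x == xc.
x = np.array([3.0, 1.0, 2.0, 0.5])
print(fitfunc(x, 1.0, 0.0))         # values grouped by branch, not in x order
print(fitfunc_masked(x, 1.0, 0.0))  # y[i] belongs to x[i]
```

This ordering matters for curve_fit, which compares the returned values with y element by element.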

Your fitfunc() is also wrong, because np.concatenate groups the values by branch, which changes their order relative to x. What you want is:

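def fit_function(x, a1, a2, xc):
    # Scalar version: evaluates one x value at a time.
    if x < xc:
        y = x**a1
    elif x > xc:
        # xc**(a1 - a2) (not x**(a1 - a2), which would reduce this branch to x**a1)
        # makes this branch approach x**a1 as x approaches xc.
        y = xc**(a1 - a2) * x**a2
    else:
        y = 0
    return y

xc = 2  # or any value you want
curvey = [fit_function(val, a1, a2, xc) for val in x]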
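
The scalar fit_function above is fine for building curvey point by point, but curve_fit calls the model once with the whole x array, and if x < xc on an array raises "The truth value of an array with more than one element is ambiguous". A vectorized alternative is np.where, which picks a branch per element. This is a sketch under stated assumptions: it uses a continuous broken power law (upper branch scaled by xc**(a1 - a2), one common convention; the new version's x**(a1 - a2) * x**a2 simplifies to x**a1 and leaves a2 without effect), keeps xc fixed through a lambda, and fits noise-free synthetic data rather than the question's data.

```python
import numpy as np
from scipy.optimize import curve_fit

def broken_power_law(x, a1, a2, xc):
    # np.where chooses a branch per element, so the output keeps the order of x
    # and curve_fit can pass the whole array at once.
    # Assumed convention: the upper branch is scaled by xc**(a1 - a2) so the
    # two branches meet at x == xc (instead of the question's y = 0 there).
    x = np.asarray(x, dtype=float)
    return np.where(x < xc, x**a1, xc**(a1 - a2) * x**a2)

# Noise-free synthetic data with a known break (illustration only).
xc = 2.0
x = np.linspace(0.5, 8.0, 40)
y = broken_power_law(x, 1.5, -0.5, xc)

# Keep xc fixed and fit only a1 and a2.
popt, pcov = curve_fit(lambda x, a1, a2: broken_power_law(x, a1, a2, xc),
                       x, y, p0=(1.0, 1.0))
a1, a2 = popt
curvey = broken_power_law(x, a1, a2, xc)  # fitted curve, same order as x
print(a1, a2)
```

Because broken_power_law already works on arrays, the same function also produces curvey directly, without a list comprehension.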



Answer 2:


Hi, define your function as follows and it will solve the problem. x is an array (or list), and the function should return y as an array (or list) of the same length. Then you can use it in curve_fit.

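import numpy as np

def fit_function(x, a1, a2, xc):
    # Evaluate each element separately so the output keeps the order of x.
    y = []
    for xx in x:
        if xx < xc:
            y.append(xx**a1)
        elif xx > xc:
            y.append(xc**(a1 - a2) * xx**a2)
        else:
            y.append(0.0)
    return np.array(y)  # return an array so y - fit_function(...) works directly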
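
If xc should be fitted as well, as in the question's new version, p0 needs one entry per free parameter: when p0 is given, curve_fit takes the number of parameters from its length, so p0=(1.0, 1.0) with a three-parameter function leads to a call without xc and a TypeError. The sketch below follows this answer's loop-based approach on noise-free synthetic data. It assumes a continuous model (x == xc uses the upper branch instead of 0.0), and it starts xc near the true break, because the residuals change non-smoothly whenever xc crosses a data point, so a poor starting value may not converge.

```python
import numpy as np
from scipy.optimize import curve_fit

def fit_function(x, a1, a2, xc):
    # Loop-based model in the style of the answer above, returning an array.
    # Assumption: continuous at the break (x == xc uses the upper branch).
    y = []
    for xx in x:
        if xx < xc:
            y.append(xx**a1)
        else:
            y.append(xc**(a1 - a2) * xx**a2)
    return np.array(y)

# Noise-free synthetic data with a known break at 2.0 (illustration only).
x = np.linspace(0.5, 8.0, 40)
y = fit_function(x, 1.5, -0.5, 2.0)

# Too few starting values: curve_fit calls fit_function(x, 1.0, 1.0) without xc.
try:
    curve_fit(fit_function, x, y, p0=(1.0, 1.0))
except TypeError as err:
    print("p0 too short:", err)

# One starting value per parameter (a1, a2, xc), with xc near the break.
popt, pcov = curve_fit(fit_function, x, y, p0=(1.2, -0.2, 1.9))
print(popt)
```

popt then holds a1, a2 and xc in the order of the function's signature.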


Source: https://stackoverflow.com/questions/32271090/curve-fitting-with-broken-power-law-in-python
