Question
I am trying to fit a top hat function to some data, i.e. f(x) is constant over the entire real line except on one segment of finite length, where it equals another constant. My parameters are the two constants of the top hat function, the midpoint, and the width, and I'm trying to use scipy.optimize.curve_fit to get all of them. Unfortunately, curve_fit has trouble obtaining the width of the hat. No matter what I do, it refuses to test any width other than the one I start with, and it fits the rest of the data very badly. The following code snippet illustrates the problem:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    # hat_level on the open interval of length hat_width centred on hat_mid, base_level elsewhere
    ret = []
    for xx in x:
        if hat_mid - hat_width/2. < xx < hat_mid + hat_width/2.:
            ret.append(hat_level)
        else:
            ret.append(base_level)
    return np.array(ret)

# True hat: base 1, height 5, centred on 0, width 1, plus uniform noise in [-0.1, 0.1)
x = np.arange(-10., 10., 0.01)
y = tophat(x, 1.0, 5.0, 0.0, 1.0) + np.random.rand(len(x))*0.2 - 0.1

# Starting guesses that differ only in the width: exact, too narrow, too wide
guesses = [[1.0, 5.0, 0.0, 1.0],
           [1.0, 5.0, 0.0, 0.1],
           [1.0, 5.0, 0.0, 2.0]]

plt.plot(x, y)
for guess in guesses:
    popt, pcov = curve_fit(tophat, x, y, p0=guess)
    print(popt)
    plt.plot(x, tophat(x, popt[0], popt[1], popt[2], popt[3]))
plt.show()
Why is curve_fit so extremely terrible at getting this right, and how can I fix it?
Answer 1:
First, the definition of tophat could use numpy.where instead of a loop:
def tophat(x, base_level, hat_level, hat_mid, hat_width):
    return np.where((hat_mid-hat_width/2. < x) & (x < hat_mid+hat_width/2.), hat_level, base_level)
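As a quick sanity check, the vectorized version can be compared against the loop version from the question on the same grid. A minimal sketch; the name tophat_loop and the extra parameter set with hat_mid = 0.3 are illustrative choices, not from the original posts:

```python
import numpy as np

def tophat_loop(x, base_level, hat_level, hat_mid, hat_width):
    # Loop-based version from the question
    ret = []
    for xx in x:
        if hat_mid - hat_width/2. < xx < hat_mid + hat_width/2.:
            ret.append(hat_level)
        else:
            ret.append(base_level)
    return np.array(ret)

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    # Same strict inequalities, evaluated on the whole array at once
    return np.where((hat_mid - hat_width/2. < x) & (x < hat_mid + hat_width/2.), hat_level, base_level)

x = np.arange(-10., 10., 0.01)
for params in [(1.0, 5.0, 0.0, 1.0), (1.0, 5.0, 0.0, 0.1), (1.0, 5.0, 0.3, 2.0)]:
    assert np.array_equal(tophat_loop(x, *params), tophat(x, *params))
print("loop and np.where versions agree")
```

Both versions perform the same floating-point comparisons, so they agree sample for sample; the np.where one just avoids the Python-level loop.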
Second, the tricky discontinuous objective function resists the least-squares algorithms that curve_fit calls, which rely on (finite-difference) derivatives. The Nelder-Mead method is usually preferable for rough functions, but curve_fit cannot use it: its method argument only accepts 'lm', 'trf' and 'dogbox'. So I set up an objective function (just the sum of absolute values of deviations) and minimized it with scipy.optimize.minimize:
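from scipy.optimize import minimize

def objective(params, x, y):
    # Sum of absolute deviations between the model and the data
    return np.sum(np.abs(tophat(x, *params) - y))

plt.plot(x, y)
for guess in guesses:
    # Nelder-Mead does not use derivatives, so a piecewise-constant objective is less of a problem
    res = minimize(objective, guess, args=(x, y), method='Nelder-Mead')
    print(res.x)
    plt.plot(x, tophat(x, *res.x))
plt.show()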
The results are better, in that starting with a too-wide hat of width 2 makes it shrink down to the correct size (see the last of three guesses).
[9.96041297e-01 5.00035502e+00 2.39462103e-04 9.99759984e-01]
[ 1.00115808e+00 4.94088711e+00 -2.21340843e-05 1.04924153e-01]
[9.95947108e-01 4.99871040e+00 1.26575116e-03 9.97908018e-01]
Unfortunately, when the starting guess is a too-narrow hat, the optimizer is still stuck (second result).
You can try other combinations of optimization method and objective function, but I haven't found one that makes the hat reliably expand.
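One combination not tried above is a global, derivative-free optimizer. scipy.optimize.differential_evolution searches inside bounds instead of starting from a single guess, so a too-narrow initial hat does not arise in the same way. A minimal sketch, assuming a hand-picked search box, a fixed seed and a seeded noise generator (none of these come from the answer above), reusing the sum-of-absolute-deviations objective:

```python
import numpy as np
from scipy.optimize import differential_evolution

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    inside = (hat_mid - hat_width / 2. < x) & (x < hat_mid + hat_width / 2.)
    return np.where(inside, hat_level, base_level)

def objective(params, x, y):
    # Sum of absolute deviations, as in the Nelder-Mead attempt
    return np.sum(np.abs(tophat(x, *params) - y))

# Same synthetic data as the question, with a seeded generator (assumption)
rng = np.random.default_rng(0)
x = np.arange(-10., 10., 0.01)
y = tophat(x, 1.0, 5.0, 0.0, 1.0) + rng.uniform(-0.1, 0.1, len(x))

# Assumed search box for (base_level, hat_level, hat_mid, hat_width)
bounds = [(0., 10.), (0., 10.), (-5., 5.), (0.01, 5.)]

# polish=False: skip the gradient-based L-BFGS-B polishing step,
# which gains nothing on a piecewise-constant objective
res = differential_evolution(objective, bounds, args=(x, y), seed=0, polish=False)
print(res.x)
```

How reliable this is for other noise realizations or looser bounds has not been established here, and differential_evolution needs many more objective evaluations than a local method.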
One thing to try is not to start the level parameters too close to their true values, since that can sometimes hurt. With
guesses = [[1.0, 1.0, 0.0, 1.0],
           [1.0, 1.0, 0.0, 0.1],
           [1.0, 1.0, 0.0, 2.0]]
I once managed to get
[ 1.00131181 4.99156649 -0.01109271 0.96822019]
[ 1.00137925 4.97879423 -0.05091561 1.096166 ]
[ 1.00130568 4.98679988 -0.01133717 0.99339777]
which is correct for all three widths. However, this worked on only some of several tries, and other attempts with the same initial points failed; the process is not robust enough. (The variation between runs comes from the noise in y, which np.random.rand regenerates on every run; SciPy's Nelder-Mead builds its initial simplex deterministically.)
Answer 2:
By its nature, non-linear least-squares fitting as with curve_fit() works with real, floating-point numbers and is not good at dealing with discrete variables. In the fit process, small changes (around the 1e-7 level) are made to each variable, and the effect of each small change on the fit result is used to decide how to change that variable to improve the fit. With discretely sampled data, small changes to your hat_mid and/or hat_width can easily be smaller than the spacing of the data points and so have no effect at all on the fit. That is why curve_fit is "extremely terrible" at this problem.
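This can be checked directly by nudging hat_width by a tiny amount and comparing model outputs. A minimal sketch, assuming a forward-difference step of sqrt(machine epsilon) times the parameter value, which is roughly what MINPACK's leastsq (behind curve_fit's default 'lm' method) uses when no Jacobian is given. The width 1.234 is a hypothetical guess chosen so the hat edges fall between samples:

```python
import numpy as np

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    inside = (hat_mid - hat_width / 2. < x) & (x < hat_mid + hat_width / 2.)
    return np.where(inside, hat_level, base_level)

x = np.arange(-10., 10., 0.01)
params = [1.0, 5.0, 0.0, 1.234]  # hypothetical guess; edges at +/-0.617 lie between samples

# Assumed forward-difference step for hat_width: sqrt(eps) * |value|
step = np.sqrt(np.finfo(float).eps) * abs(params[3])
perturbed = params[:3] + [params[3] + step]

f0 = tophat(x, *params)
f1 = tophat(x, *perturbed)
dfdw = (f1 - f0) / step  # numerical derivative of the model w.r.t. hat_width
print(np.abs(dfdw).max())  # 0.0: no sample changed, so the width looks irrelevant
```

With a zero derivative column, the solver has no information about which way to move the width, so it stays where it started.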
You may find that giving the steps a finite width (comparable to the step size of your discrete data) helps the fit find where the edges of your hat are.
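One way to follow that suggestion is to replace the sharp steps with tanh ramps. This is a minimal sketch, not code from either answer. smooth_tophat and EDGE are hypothetical names, EDGE = 0.02 (about two data steps) is an assumed softness, the tanh ramp is just one possible smooth step, and the noise is seeded for repeatability:

```python
import numpy as np
from scipy.optimize import curve_fit

EDGE = 0.02  # assumed edge softness, about two grid steps of the data below

def smooth_tophat(x, base_level, hat_level, hat_mid, hat_width):
    # tanh ramps instead of sharp steps: the model now changes continuously
    # with hat_mid and hat_width, so finite differences are non-zero
    left = hat_mid - hat_width / 2.
    right = hat_mid + hat_width / 2.
    bump = 0.5 * (np.tanh((x - left) / EDGE) - np.tanh((x - right) / EDGE))
    return base_level + (hat_level - base_level) * bump

def tophat(x, base_level, hat_level, hat_mid, hat_width):
    # Sharp top hat, only used to generate the synthetic data
    inside = (hat_mid - hat_width / 2. < x) & (x < hat_mid + hat_width / 2.)
    return np.where(inside, hat_level, base_level)

rng = np.random.default_rng(0)
x = np.arange(-10., 10., 0.01)
y = tophat(x, 1.0, 5.0, 0.0, 1.0) + rng.uniform(-0.1, 0.1, len(x))

# The question's guesses: exact, too narrow, too wide
guesses = [[1.0, 5.0, 0.0, 1.0],
           [1.0, 5.0, 0.0, 0.1],
           [1.0, 5.0, 0.0, 2.0]]
fits = []
for guess in guesses:
    # maxfev raised (assumption) because the edges move only a little per iteration
    popt, pcov = curve_fit(smooth_tophat, x, y, p0=guess, maxfev=10000)
    fits.append(popt)
    print(popt)
```

With the ramp, samples near each edge respond to small changes in hat_mid and hat_width, so curve_fit can move the edges step by step toward the data. A much smaller EDGE than the sample spacing tends to bring back the zero-derivative problem, while a much larger one blurs the edges.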
Source: https://stackoverflow.com/questions/49878701/scipy-curve-fit-cannot-fit-a-tophat-function