How to Fit to The Outer Shell of a Function

Submitted by 拟墨画扇 on 2019-12-02 06:35:35

This is an iterative approach similar to this post. It differs in that the shape of the graph does not permit the use of a convex hull. The idea is to create a cost function that tries to minimize the area under the graph while paying a high cost whenever a data point lies above the graph. Depending on the type of graph in the OP, the cost function needs to be adapted. One also has to check that in the final result all points really are below the graph. Here one can fiddle with the details of the cost function. One may, e.g., include an offset in the tanh, like tanh( slope * ( x - offset ) ), to push the solution farther away from the data.
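A minimal sketch of the penalty term on its own, assuming only NumPy. The helper name `soft_penalty` and the sample residuals are illustrative and not part of the answer. For a residual d = y - g(x), the tanh term is close to 0 when d < 0 (point below the curve) and close to 2*val when d > 0 (point above it). A larger slope makes this step sharper. The optional `offset` is the shift mentioned above: with a negative offset, the penalty already switches on while a point is still slightly below the curve.

```python
import numpy as np

def soft_penalty(d, slope, val, offset=0.0):
    # Hypothetical helper isolating the tanh term of cost_function.
    # d = y - g(x): d > 0 means the data point lies above the curve.
    # Result is ~0 for d well below offset and ~2*val for d well above it;
    # a negative offset switches the penalty on before the point crosses the curve.
    return val * (1.0 + np.tanh(slope * (np.asarray(d, dtype=float) - offset)))

# illustrative residuals: below, barely below, on, barely above, above the curve
d = np.array([-0.5, -0.01, 0.0, 0.01, 0.5])
for slope in (10, 1000):
    # a larger slope turns the smooth transition into an almost hard step
    print(slope, np.round(soft_penalty(d, slope, val=1.0), 3))
```

The answer's code raises the slope gradually (10**i for i = 1..5) and starts each fit from the previous solution. This is a common way to approach a nearly discontinuous penalty without starting from the hardest version of the problem.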

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import leastsq

def g( x, a, s ): 
    return a * np.exp(-x**2 / s**2 )

def cost_function( params, xData, yData, slope, val ):
    a,s = params
    # area under g for x >= 0 (half of the full area a * s * sqrt(pi)); leastsq squares this residual
    area = 0.5 * np.sqrt( np.pi ) * a * s
    # diff > 0 means the data point lies above the curve
    diff = yData - g( xData, a, s )
    # smooth step: about 0 for points below the curve, about 2 * val for points above it
    cDiff = val * ( 1 + np.tanh( slope * diff ) )
    out = np.concatenate( [ [area] , cDiff ] )
    return out

xData = np.linspace( -5, 5, 500 )
# synthetic data: Gaussian envelope (a = 0.77, s = 2) modulated by a fast sin^2
yData = g( xData, .77, 2 ) * np.sin( 257.7 * xData )**2


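# sharpen the tanh step (slope = 10, 100, ..., 10**5), starting each fit from the previous solution
sol=[ [ 1, 2.2 ] ]
for i in range( 1, 6 ):
    solN, ier = leastsq( cost_function, sol[-1] , args=( xData, yData, 10**i, 1 ) )
    sol += [ solN ]
    print( solN )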

fig = plt.figure()
ax = fig.add_subplot( 1, 1, 1)
ax.scatter( xData, yData, s=1 ) 
for solN in sol:
    solY = g( xData, *solN )  # evaluate each intermediate fit on the grid
    ax.plot( xData, solY ) 
plt.show()

giving

>> [0.8627445  3.55774814]
>> [0.77758636 2.52613376]
>> [0.76712184 2.1181137 ]
>> [0.76874125 2.01910211]
>> [0.7695663  2.00262339]

and a plot of the data together with each successive fit (image not preserved).
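A minimal sketch of the final check recommended at the top of this answer: compute the largest residual y - g(x) and confirm it is not positive. The helper name `max_violation` is hypothetical. The model and data are the same as in the code above, computed without `np.fromiter`, and the two parameter pairs are the first and last printed solutions. A value <= 0 means no data point lies above the fitted curve.

```python
import numpy as np

def g(x, a, s):
    # same Gaussian model as in the answer
    return a * np.exp(-x**2 / s**2)

def max_violation(xData, yData, a, s):
    # Hypothetical helper: largest amount by which a data point lies above g(x; a, s).
    # A value <= 0 means every point is on or below the curve.
    return float(np.max(yData - g(xData, a, s)))

# same synthetic data as in the answer, computed without np.fromiter
xData = np.linspace(-5, 5, 500)
yData = g(xData, .77, 2) * np.sin(257.7 * xData)**2

# first and last solutions printed by the loop above
for a, s in [(0.8627445, 3.55774814), (0.7695663, 2.00262339)]:
    v = max_violation(xData, yData, a, s)
    print(a, s, "all points on or below" if v <= 0 else f"points above the curve by up to {v:.2e}")
```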

Here is a different approach using SciPy's Differential Evolution module combined with a "brick wall". If any predicted value during the fit is less than the corresponding Y value (that is, a data point lies above the fitted curve), the fitting error is made extremely large. I have shamelessly poached code from the answer of @mikuszefski to generate the data used in this example.

import warnings

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import differential_evolution

def g( x, a, s ): 
    return a * np.exp(-x**2 / s**2 )

xData = np.linspace( -5, 5, 500 )
yData = g( xData, .77, 2 ) * np.sin( 257.7 * xData )**2


def Gauss(x, a, x0, sigma, offset):
    return a * np.exp(-np.power(x - x0,2) / (2 * np.power(sigma,2))) + offset


# function for the genetic algorithm to minimize (sum of squared error)
def sumOfSquaredError(parameterTuple):
    warnings.filterwarnings("ignore") # do not print warnings by genetic algorithm
    val = Gauss(xData, *parameterTuple)
    multiplier = 1.0
    if np.any(val < yData): # ****** brick wall ****** some data point lies above the curve
        multiplier = 1.0E10
    return np.sum((multiplier * (yData - val)) ** 2.0)


def generate_Initial_Parameters():
    # min and max used for bounds
    maxX = max(xData)
    minX = min(xData)
    maxY = max(yData)
    minY = min(yData)

    minData = min(minX, minY)
    maxData = max(maxX, maxY)

    parameterBounds = []
    parameterBounds.append([minData, maxData]) # parameter bounds for a
    parameterBounds.append([minData, maxData]) # parameter bounds for x0
    parameterBounds.append([minData, maxData]) # parameter bounds for sigma
    parameterBounds.append([minData, maxData]) # parameter bounds for offset

    # "seed" the numpy random number generator for repeatable results
    result = differential_evolution(sumOfSquaredError, parameterBounds, seed=3, polish=False)
    return result.x

# generate initial parameter values
geneticParameters = generate_Initial_Parameters()

# create values for display of fitted function
y_fit = Gauss(xData, *geneticParameters)

plt.scatter(xData, yData, s=1 ) # plot the raw data
plt.plot(xData, y_fit) # plot the equation using the fitted parameters
plt.show()

print('parameters:', geneticParameters)
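To isolate the brick-wall idea, here is a minimal sketch on hypothetical toy data: fitting a single constant c that must stay on or above four points. The function name `brick_wall_sse` and the data are illustrative assumptions. Only `scipy.optimize.differential_evolution` is the real SciPy call used in the answer. Without the wall, the least-squares optimum would be the mean of y (0.45). With the wall, every c below max(y) = 0.9 is heavily penalized, so the search settles just at or above 0.9.

```python
import numpy as np
from scipy.optimize import differential_evolution

# Hypothetical toy data: a constant "envelope" c must stay on or above every point.
y = np.array([0.1, 0.5, 0.9, 0.3])

def brick_wall_sse(params):
    # Hypothetical objective mirroring sumOfSquaredError with a constant model.
    c = params[0]
    # brick wall: if any data point lies above the model, scale the error by 1e10
    multiplier = 1.0e10 if np.any(c < y) else 1.0
    return np.sum((multiplier * (y - c)) ** 2)

result = differential_evolution(brick_wall_sse, [(0.0, 2.0)], seed=3, polish=False)
print("c =", result.x[0], "max(y) =", y.max())
```

The wall makes the objective discontinuous, so a derivative-free search such as differential evolution suits it better than a gradient-based local fit.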