Python: two-curve gaussian fitting with non-linear least-squares

野性不改 2020-12-13 07:39

My knowledge of maths is limited, which is probably why I am stuck. I have a spectrum to which I am trying to fit two Gaussian peaks. I can fit the largest peak, but I cann…

3 Answers
  •  悲&欢浪女
    2020-12-13 07:51

    This code worked for me, provided that you are only fitting a function that is a combination of two Gaussian distributions.

    I wrote a residuals function that adds two Gaussian functions and subtracts their sum from the real data.

    The parameters (p) that I passed to SciPy's least-squares function (scipy.optimize.leastsq) are: the mean of the first Gaussian (m), the difference between the means of the first and second Gaussians (dm, i.e. the horizontal shift), the standard deviation of the first (sd1), and the standard deviation of the second (sd2).
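    A quick way to check the (m, dm) encoding is to evaluate the residuals at the parameters that generated the test data. With mean1 = 0 and mean2 = -2, the encoding gives m = 0 and dm = -2 - 0 = -2, so the residual vector should be exactly zero there, while flipping the sign of dm should leave a large residual. This minimal sketch repeats the `norm` and `res` functions from the answer's code (with `norm` vectorized); `p_true` and `p_wrong` are names introduced here for illustration only.

```python
import numpy as np

def norm(x, mean, sd):
    # Unit-area Gaussian (normal) density, evaluated element-wise on x
    return np.exp(-(x - mean) ** 2 / (2 * sd ** 2)) / (sd * np.sqrt(2 * np.pi))

def res(p, y, x):
    # p = [m, dm, sd1, sd2]; the second mean is m + dm, not an independent value
    m, dm, sd1, sd2 = p
    y_fit = norm(x, m, sd1) + norm(x, m + dm, sd2)
    return y - y_fit

x = np.linspace(-20, 20, 500)
# Same test data as the answer: means 0 and -2, standard deviations 0.5 and 1
y_real = norm(x, 0, 0.5) + norm(x, -2, 1)

p_true = [0.0, -2.0, 0.5, 1.0]   # m = 0, dm = -2
p_wrong = [0.0, 2.0, 0.5, 1.0]   # dm with the wrong sign puts the second peak at +2

print("max |residual| at p_true: ", np.max(np.abs(res(p_true, y_real, x))))
print("max |residual| at p_wrong:", np.max(np.abs(res(p_wrong, y_real, x))))
```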

    import numpy as np
    from scipy.optimize import leastsq
    import matplotlib.pyplot as plt
    
    ######################################
    # Setting up test data
    def norm(x, mean, sd):
      # Gaussian (normal) probability density, evaluated element-wise on the array x
      return np.exp(-(x - mean)**2 / (2 * sd**2)) / (sd * np.sqrt(2 * np.pi))
    
    mean1, mean2 = 0, -2
    std1, std2 = 0.5, 1 
    
    x = np.linspace(-20, 20, 500)
    y_real = norm(x, mean1, std1) + norm(x, mean2, std2)
    
    ######################################
    # Solving
    m, dm, sd1, sd2 = [5, 10, 1, 1]
    p = [m, dm, sd1, sd2] # Initial guesses for leastsq
    y_init = norm(x, m, sd1) + norm(x, m + dm, sd2) # For final comparison plot
    
    def res(p, y, x):
      m, dm, sd1, sd2 = p
      m1 = m
      m2 = m1 + dm
      y_fit = norm(x, m1, sd1) + norm(x, m2, sd2)
      err = y - y_fit
      return err
    
    plsq = leastsq(res, p, args=(y_real, x))
    m_fit, dm_fit, sd1_fit, sd2_fit = plsq[0]  # best-fit parameters (plsq[1] is an integer status flag)
    
    y_est = norm(x, m_fit, sd1_fit) + norm(x, m_fit + dm_fit, sd2_fit)
    
    plt.plot(x, y_real, label='Real Data')
    plt.plot(x, y_init, 'r.', label='Starting Guess')
    plt.plot(x, y_est, 'g.', label='Fitted')
    plt.legend()
    plt.show()
    

    (Figure: results of the code, comparing the real data, the starting guess and the fitted curve.)
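    The `norm` function above is a normalized probability density, so each peak is forced to have unit area and only the positions and widths are fitted. Peaks in a measured spectrum usually have arbitrary heights, so a model with a free amplitude per peak is often needed. The sketch below swaps in `scipy.optimize.curve_fit` (a higher-level wrapper for non-linear least-squares fitting that also returns a covariance estimate) instead of calling `leastsq` directly. The function name `two_gaussians`, the synthetic noisy spectrum and the starting-guess heuristic are assumptions for illustration, not part of the original answer.

```python
import numpy as np
from scipy.optimize import curve_fit

def two_gaussians(x, a1, mu1, s1, a2, mu2, s2):
    # Sum of two Gaussian peaks with free heights a1, a2 (not unit-area densities)
    return (a1 * np.exp(-(x - mu1) ** 2 / (2 * s1 ** 2))
            + a2 * np.exp(-(x - mu2) ** 2 / (2 * s2 ** 2)))

# Hypothetical test spectrum: two overlapping peaks of different heights plus a little noise
rng = np.random.default_rng(0)
x = np.linspace(-20, 20, 500)
true_params = (3.0, 0.0, 0.5, 1.0, -2.0, 1.0)
y = two_gaussians(x, *true_params) + rng.normal(0, 0.02, x.size)

# Rough starting guesses (assumed heuristic): tallest peak first, second peak nearby and smaller
i_max = np.argmax(y)
p0 = [y.max(), x[i_max], 1.0, 0.5 * y.max(), x[i_max] - 2.0, 1.0]

popt, pcov = curve_fit(two_gaussians, x, y, p0=p0)
perr = np.sqrt(np.diag(pcov))  # 1-sigma uncertainty estimates for the fitted parameters

for name, val, err in zip(["a1", "mu1", "s1", "a2", "mu2", "s2"], popt, perr):
    print(f"{name} = {val:.3f} +/- {err:.3f}")
```

    As with `leastsq`, the result depends on the starting guess `p0`; if the second peak is missed, moving its initial centre closer to where it appears in the data usually helps.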
