Python 3D polynomial surface fit, order dependent

后端 未结 4 1134
不思量自难忘°
不思量自难忘° 2020-12-01 00:52

I am currently working with astronomical data, which includes comet images. I would like to remove the background sky gradient that appears in these images because of the time of capture.

4条回答
  •  挽巷
    挽巷 (楼主)
    2020-12-01 01:47

    This follows the principle of least squares and imitates Kington's style, but replaces the single order argument m with two arguments, m_1 and m_2 (the highest powers of x and y, respectively).

    import numpy as np
    import matplotlib.pyplot as plt
    
    import itertools
    
    
    # w = (Phi^T Phi)^{-1} Phi^T t
    # where Phi_{k, j + i (m_2 + 1)} = x_k^i y_k^j,
    #       t_k = z_k,
    #           i = 0, 1, ..., m_1,
    #           j = 0, 1, ..., m_2,
    #           k = 0, 1, ..., n - 1
    def polyfit2d(x, y, z, m_1, m_2):
        """Fit a 2D polynomial surface to sample points by least squares.

        The fitted model is::

            z ~ sum_{i=0..m_1} sum_{j=0..m_2} w[j + i * (m_2 + 1)] * x**i * y**j

        Parameters
        ----------
        x, y : array_like
            Coordinates of the samples. Arrays of any shape are accepted and
            flattened, so a meshgrid can be passed directly.
        z : array_like
            Sample values. When ``z`` has the same shape as ``x`` it is
            flattened alongside it; otherwise it is passed through unchanged
            as the right-hand side of the least-squares problem.
        m_1, m_2 : int
            Highest power of ``x`` and of ``y`` in the model.

        Returns
        -------
        w : ndarray
            Coefficients, ordered so that ``w[j + i * (m_2 + 1)]`` multiplies
            ``x**i * y**j``.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        z = np.asarray(z)
        # Grid-shaped samples: flatten z together with the coordinates.
        if z.shape == x.shape:
            z = z.ravel()
        x = x.ravel()
        y = y.ravel()

        # Design matrix: column h holds x^i y^j for the h-th (i, j) pair.
        ncols = (m_1 + 1) * (m_2 + 1)
        Phi = np.empty((x.size, ncols))
        ij = itertools.product(range(m_1 + 1), range(m_2 + 1))
        for h, (i, j) in enumerate(ij):
            Phi[:, h] = x ** i * y ** j

        # Solve min ||Phi w - z|| directly instead of forming Phi^T Phi:
        # the normal equations square the condition number and fail outright
        # when Phi is rank deficient, whereas lstsq returns the minimum-norm
        # solution.
        w, _, _, _ = np.linalg.lstsq(Phi, z, rcond=None)
        return w
    
    
    # t' = Phi' w
    # where Phi'_{k, j + i (m_2 + 1)} = x'_k^i y'_k^j,
    #       t'_k = z'_k,
    #           i = 0, 1, ..., m_1,
    #           j = 0, 1, ..., m_2,
    #           k = 0, 1, ..., n' - 1
    def polyval2d(x_, y_, w, m_1, m_2):
        """Evaluate a fitted 2D polynomial surface at the given points.

        Computes ``sum_{i=0..m_1} sum_{j=0..m_2} w[j + i * (m_2 + 1)] *
        x_**i * y_**j`` for every point.

        Parameters
        ----------
        x_, y_ : array_like
            Evaluation coordinates. They are broadcast against each other and
            may have any shape (for example the 2D arrays from
            ``np.meshgrid``); scalars are treated as one-element arrays.
        w : array_like
            Coefficients in the ordering described above; the first axis must
            have length ``(m_1 + 1) * (m_2 + 1)``.
        m_1, m_2 : int
            Highest power of ``x_`` and of ``y_`` in the model.

        Returns
        -------
        z_ : ndarray
            Surface values with the broadcast shape of ``x_`` and ``y_``,
            followed by any trailing axes of ``w``.
        """
        xb, yb = np.broadcast_arrays(
            np.atleast_1d(np.asarray(x_, dtype=float)),
            np.atleast_1d(np.asarray(y_, dtype=float)),
        )
        w = np.asarray(w)
        out_shape = xb.shape
        # Work on flat copies so the design matrix is always 2D.
        xf = xb.ravel()
        yf = yb.ravel()

        # Design matrix: column h holds x'^i y'^j for the h-th (i, j) pair.
        ncols = (m_1 + 1) * (m_2 + 1)
        Phi_ = np.empty((xf.size, ncols))
        ij = itertools.product(range(m_1 + 1), range(m_2 + 1))
        for h, (i, j) in enumerate(ij):
            Phi_[:, h] = xf ** i * yf ** j

        # t' = Phi' w, then restore the shape of the input points.
        z_ = Phi_.dot(w)
        return z_.reshape(out_shape + w.shape[1:])
    
    
    if __name__ == "__main__":
        # Demo: fit a polynomial surface to noisy random samples, evaluate it
        # on a dense grid, and plot the surface with the samples on top.

        # Polynomial orders in x and y, shared by the fit and the evaluation.
        m_1, m_2 = 3, 2

        # Generate x, y, z
        n = 100
        x = np.random.random(n)
        y = np.random.random(n)
        z = x ** 2 + y ** 2 + 3 * x ** 3 + y + np.random.random(n)

        # Generate w
        w = polyfit2d(x, y, z, m_1=m_1, m_2=m_2)

        # Generate x', y', z' (rows of the grid vary in y, columns in x)
        n_ = 1000
        x_, y_ = np.meshgrid(np.linspace(x.min(), x.max(), n_),
                             np.linspace(y.min(), y.max(), n_))
        z_ = np.zeros((n_, n_))
        for i, (x_row, y_row) in enumerate(zip(x_, y_)):
            z_[i, :] = polyval2d(x_row, y_row, w, m_1=m_1, m_2=m_2)

        # Plot. imshow's extent is (left, right, bottom, top); row 0 of z_
        # corresponds to the smallest y, so draw with origin="lower" to keep
        # the image aligned with the scatter points.
        plt.imshow(z_, origin="lower",
                   extent=(x_.min(), x_.max(), y_.min(), y_.max()))
        plt.scatter(x, y, c=z)
        plt.show()
    

提交回复
热议问题