Find and draw a regression plane for a set of points

前端 未结 3 1908
傲寒
傲寒 2020-12-29 10:40

I want to fit a plane to some data points and draw it. My current code is this:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.         


        
相关标签:
3条回答
  • 2020-12-29 10:50

    Another way is a straightforward least-squares solution. The equation for a plane is $ax + by + c = z$, so set up matrices like this with all your data:

    $$A = \begin{bmatrix} x_0 & y_0 & 1 \\ x_1 & y_1 & 1 \\ \vdots & \vdots & \vdots \\ x_n & y_n & 1 \end{bmatrix}$$
    

    And

    $$x = \begin{bmatrix} a \\ b \\ c \end{bmatrix}$$
    

    And

    $$B = \begin{bmatrix} z_0 \\ z_1 \\ \vdots \\ z_n \end{bmatrix}$$
    

    In other words, $Ax = B$. Now solve for $x$, which holds your coefficients. Since (I assume) you have more than 3 points, the system is over-determined, so you need to use the left pseudo-inverse. The answer is:

    $$\begin{bmatrix} a \\ b \\ c \end{bmatrix} = (A^T A)^{-1} A^T B$$
    

    And here is some simple Python code with an example:

    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np
    
    N_POINTS = 10
    TARGET_X_SLOPE = 2
    TARGET_Y_SLOPE = 3
    TARGET_OFFSET = 5
    EXTENTS = 5
    NOISE = 5

    # Create random data: x and y uniform in [-EXTENTS, EXTENTS), z on the
    # target plane plus Gaussian noise. Both bounds are passed explicitly
    # because np.random.uniform's first positional argument is `low`.
    xs = np.random.uniform(-EXTENTS, EXTENTS, N_POINTS)
    ys = np.random.uniform(-EXTENTS, EXTENTS, N_POINTS)
    zs = (xs * TARGET_X_SLOPE
          + ys * TARGET_Y_SLOPE
          + TARGET_OFFSET
          + np.random.normal(scale=NOISE, size=N_POINTS))

    # plot raw data
    plt.figure()
    ax = plt.subplot(111, projection='3d')
    ax.scatter(xs, ys, zs, color='b')

    # Do the fit: solve A @ [a, b, c] = z in the least-squares sense, where
    # each row of A is [x_i, y_i, 1]. lstsq computes the same solution as
    # (A^T A)^-1 A^T z without forming an explicit inverse (and without the
    # deprecated np.matrix class).
    A = np.column_stack([xs, ys, np.ones_like(xs)])
    fit, _, _, _ = np.linalg.lstsq(A, zs, rcond=None)
    errors = zs - A @ fit
    residual = np.linalg.norm(errors)

    print("solution:")
    print("%f x + %f y + %f = z" % (fit[0], fit[1], fit[2]))
    print("errors:")
    print(errors)
    print("residual:")
    print(residual)

    # Plot the plane, evaluated over the whole grid at once.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    X, Y = np.meshgrid(np.arange(xlim[0], xlim[1]),
                       np.arange(ylim[0], ylim[1]))
    Z = fit[0] * X + fit[1] * Y + fit[2]
    ax.plot_wireframe(X, Y, Z, color='k')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()
    
    0 讨论(0)
  • 2020-12-29 10:50

    Thanks @Ben for sharing! Since np.matrix is deprecated, I edited your code so it works with NumPy arrays.

    import matplotlib.pyplot as plt
    import numpy as np
    from numpy.linalg import inv
    
    def plane_from_points(points):
        """Least-squares fit of the plane z = a*x + b*y + c to 3-D points.

        Args:
            points: array-like of shape (3, N). Rows 0, 1 and 2 hold the x, y
                and z coordinates of the N points. Lists are accepted as well
                as NumPy arrays.

        Returns:
            ndarray of shape (3, 1) holding [[a], [b], [c]].
        """
        points = np.asarray(points, dtype=float)
        # Build the design matrix directly as columns [x, y, 1]; no transpose
        # of a row stack is needed.
        A = np.column_stack([points[0], points[1], np.ones(points.shape[1])])
        # Keep b as an (N, 1) column so the solution comes back as (3, 1).
        b = points[2][:, np.newaxis]

        # lstsq minimizes ||A @ fit - b|| like inv(A.T @ A) @ A.T @ b, but it
        # never forms an explicit inverse. That makes it better conditioned,
        # and for rank-deficient A (e.g. collinear points) it returns the
        # minimum-norm solution instead of raising or returning garbage.
        fit, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
        return fit
    
    N_POINTS = 10
    TARGET_X_SLOPE = 2
    TARGET_Y_SLOPE = 3
    TARGET_OFFSET = 5
    EXTENTS = 5
    NOISE = 3

    # Create random data: x and y uniform in [-EXTENTS, EXTENTS), z on the
    # target plane plus Gaussian noise. Both bounds are passed explicitly
    # because np.random.uniform's first positional argument is `low`.
    xs = np.random.uniform(-EXTENTS, EXTENTS, N_POINTS)
    ys = np.random.uniform(-EXTENTS, EXTENTS, N_POINTS)
    zs = (xs * TARGET_X_SLOPE
          + ys * TARGET_Y_SLOPE
          + TARGET_OFFSET
          + np.random.normal(scale=NOISE, size=N_POINTS))

    points = np.array([xs, ys, zs])

    # plot raw data
    plt.figure()
    ax = plt.subplot(111, projection='3d')
    ax.scatter(xs, ys, zs, color='b')

    fit = plane_from_points(points)
    # plane_from_points returns a (3, 1) column; flatten it to three scalars
    # so the plane can be evaluated on the whole grid without per-cell
    # array-to-scalar conversions.
    a, b, c = np.ravel(fit)

    # plot plane
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    X, Y = np.meshgrid(np.arange(xlim[0], xlim[1]),
                       np.arange(ylim[0], ylim[1]))
    Z = a * X + b * Y + c

    ax.plot_wireframe(X, Y, Z, color='k')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()
    
    0 讨论(0)
  • 2020-12-29 11:05

    Oh, the idea just came to me. It's quite easy. :-)

    import numpy as np
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.pyplot as plt
    import scipy.optimize
    import functools
    
    def plane(x, y, params):
        """Evaluate z = a*x + b*y + c, where a, b, c are params[0], params[1], params[2]."""
        return params[0] * x + params[1] * y + params[2]
    
    def error(params, points):
        """Return the sum of squared vertical distances from points to a plane.

        params: the plane coefficients (a, b, c), forwarded to plane().
        points: an iterable of (x, y, z) triples.
        Returns 0 when points is empty.
        """
        return sum(abs(plane(x, y, params) - z) ** 2 for x, y, z in points)
    
    def cross(a, b):
        """Return the 3-D cross product a x b as a list of three components."""
        a_x, a_y, a_z = a[0], a[1], a[2]
        b_x, b_y, b_z = b[0], b[1], b[2]
        return [
            a_y * b_z - a_z * b_y,
            a_z * b_x - a_x * b_z,
            a_x * b_y - a_y * b_x,
        ]
    
    points = [
        (1.1, 2.1, 8.1),
        (3.2, 4.2, 8.0),
        (5.3, 1.3, 8.2),
        (3.4, 2.4, 8.3),
        (1.5, 4.5, 8.0),
    ]

    # Fit (a, b, c) of z = a*x + b*y + c by numerically minimizing the summed
    # squared vertical error, starting from the flat plane z = 0.
    objective = functools.partial(error, points=points)
    res = scipy.optimize.minimize(objective, [0, 0, 0])
    a, b, c = res.x

    xs, ys, zs = zip(*points)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(xs, ys, zs)

    # Two in-plane direction vectors are (1, 0, a) and (0, 1, b); their cross
    # product is the plane normal, and (0, 0, c) is a point on the plane.
    normal = np.array(cross([1, 0, a], [0, 1, b]))
    point = np.array([0.0, 0.0, c])
    d = -point.dot(normal)
    xx, yy = np.meshgrid([-5, 10], [-5, 10])
    z = (-normal[0] * xx - normal[1] * yy - d) * 1. / normal[2]
    ax.plot_surface(xx, yy, z, alpha=0.2, color=[0, 1, 0])

    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.set_zlim(0, 10)

    plt.show()
    

    (Image: the fitted regression plane)

    Sorry for asking unnecessarily.

    0 讨论(0)
提交回复
热议问题