Plot a linear model in 3D with Matplotlib

自闭症患者 2020-12-09 07:07

I'm trying to create a 3d plot of a linear model fit for a data set. I was able to do this relatively easily in R, but I'm really struggling to do the same in Python. Here

2 answers
  • 2020-12-09 07:24

    Got it!

    The problem I described in the comments on mdurant's answer was that the surface was not plotted as a nice square grid like the ones in "Combining scatter plot with surface plot".

    I realized that the problem was my meshgrid, so I corrected both ranges (x and y) and used proportional steps for np.arange.
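
A minimal NumPy sketch of the mesh described above, assuming the ranges used in this answer's code (TV from 0 to 350 in steps of 20, Radio from 0 to 60 in steps of 4). It shows the grid shape `meshgrid` produces and that flattening with `ravel` and restoring with `reshape` round-trips cleanly, which is what the prediction step relies on. The stand-in value `tv + radio` is purely illustrative and replaces the model prediction.

```python
import numpy as np

# Ranges assumed from the answer's code: TV roughly 0-350, Radio roughly 0-60.
# Step sizes (20 and 4) keep the number of points per axis comparable.
tv = np.arange(0, 350, 20)      # 18 points: 0, 20, ..., 340
radio = np.arange(0, 60, 4)     # 15 points: 0, 4, ..., 56

# With the default indexing='xy', both arrays have shape (len(radio), len(tv)).
tv_grid, radio_grid = np.meshgrid(tv, radio)

# A model's predict() wants flat columns, so the grid is raveled (row-major) ...
tv_flat = tv_grid.ravel()
radio_flat = radio_grid.ravel()

# ... and the per-point results are reshaped back onto the grid for plot_surface.
# tv + radio is a stand-in for the model prediction (illustration only).
z_flat = tv_flat + radio_flat
z_grid = z_flat.reshape(tv_grid.shape)

print(tv_grid.shape)  # (15, 18)
```

Because `ravel` and `reshape` both use row-major order by default, point `i` of the flat prediction lands back at the same grid cell it came from.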

    This let me use the code from mdurant's answer, and it worked perfectly!

    Here's the result:

    [Figure: 3D scatter plot with OLS plane]

    And here's the code:

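    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import statsmodels.formula.api as sm
    
    # Advertising data set from ISL (URL as in the original answer; it may no longer be reachable)
    csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
    model = sm.ols(formula='Sales ~ TV + Radio', data=csv)
    fit = model.fit()
    
    print(fit.summary())
    
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    
    # Generate a mesh covering the data, with step sizes proportional to each range
    x_surf = np.arange(0, 350, 20)
    y_surf = np.arange(0, 60, 4)
    x_surf, y_surf = np.meshgrid(x_surf, y_surf)
    
    # predict() expects one row per point with the column names used in the formula,
    # so flatten the mesh into a DataFrame, predict, then reshape back to the mesh shape.
    # fit.predict returns a pandas Series here, which has no .reshape, hence np.asarray.
    exog = pd.DataFrame({'TV': x_surf.ravel(), 'Radio': y_surf.ravel()})
    out = fit.predict(exog=exog)
    ax.plot_surface(x_surf, y_surf,
                    np.asarray(out).reshape(x_surf.shape),
                    rstride=1,
                    cstride=1,
                    color='None',
                    alpha=0.4)
    
    # Overlay the observed data points
    ax.scatter(csv['TV'], csv['Radio'], csv['Sales'],
               c='blue',
               marker='o',
               alpha=1)
    
    ax.set_xlabel('TV')
    ax.set_ylabel('Radio')
    ax.set_zlabel('Sales')
    
    plt.show()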
    
  • 2020-12-09 07:30

    You were right that plot_surface wants a meshgrid of coordinates, but predict wants a data structure like the one the model was fitted with (the "exog"): one row per point, with the same column names.

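    # xx, yy: the meshgrid arrays from the question
    exog = pd.DataFrame({'TV': xx.ravel(), 'Radio': yy.ravel()})
    out = fit.predict(exog=exog)
    # predict returns a pandas Series; convert to an array before reshaping to the grid
    ax.plot_surface(xx, yy, np.asarray(out).reshape(xx.shape), color='None')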
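
The flatten, predict, reshape pattern in this answer does not depend on statsmodels. A minimal sketch that runs with NumPy alone, using `np.linalg.lstsq` as a stand-in for the OLS fit and hypothetical noise-free data generated from a known plane (not the Advertising data). It is an illustration of the idea under those assumptions, not the statsmodels code path.

```python
import numpy as np

# Hypothetical, noise-free data on a known plane (not the Advertising data):
# sales = 3 + 0.05 * tv + 0.2 * radio
rng = np.random.default_rng(0)
tv = rng.uniform(0, 300, size=50)
radio = rng.uniform(0, 50, size=50)
sales = 3 + 0.05 * tv + 0.2 * radio

# Design matrix with an intercept column, playing the role of 'Sales ~ TV + Radio'
X = np.column_stack([np.ones_like(tv), tv, radio])
# Ordinary least squares via NumPy (stand-in for statsmodels' ols().fit());
# with noise-free data this recovers roughly [3, 0.05, 0.2]
coef, *_ = np.linalg.lstsq(X, sales, rcond=None)

# plot_surface needs 2-D coordinate arrays ...
xx, yy = np.meshgrid(np.linspace(0, 300, 20), np.linspace(0, 50, 10))

# ... but the model takes one row per point: flatten the grid into columns,
# predict, then reshape the predictions back to the grid's shape
X_grid = np.column_stack([np.ones(xx.size), xx.ravel(), yy.ravel()])
zz = (X_grid @ coef).reshape(xx.shape)

print(zz.shape)  # (10, 20)
```

With statsmodels, the DataFrame of flattened `TV` and `Radio` columns takes the place of `X_grid` (the formula interface adds the intercept itself), and reshaping the returned predictions to `xx.shape` is the same final step.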
    