Plot normal distribution in 3D

无人及你 2020-12-29 15:38

I am trying to plot the joint distribution of two normally distributed variables.

The code below plots one normally distributed variable. What would the code be for plotting …
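For reference, the joint density of two normally distributed variables with mean vector μ and covariance matrix Σ (the `mu` and `sigma` passed to `scipy.stats.multivariate_normal` in the answer below) is the bivariate normal density. When the two variables are independent, Σ = diag(σ_x², σ_y²), so det Σ = σ_x²σ_y² and the density factors into the product of the two one-dimensional normal densities:

```latex
f(x, y) = \frac{1}{2\pi\sqrt{\det\Sigma}}
  \exp\!\left(-\tfrac{1}{2}\,(\mathbf{p}-\boldsymbol{\mu})^{\top}\Sigma^{-1}(\mathbf{p}-\boldsymbol{\mu})\right),
  \qquad \mathbf{p} = \begin{pmatrix} x \\ y \end{pmatrix}

% Independent case, \Sigma = \operatorname{diag}(\sigma_x^2, \sigma_y^2):
f(x, y) = \frac{1}{2\pi\sigma_x\sigma_y}
  \exp\!\left(-\frac{(x-\mu_x)^2}{2\sigma_x^2} - \frac{(y-\mu_y)^2}{2\sigma_y^2}\right)
```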

3 Answers
  •  萌比男神i
    2020-12-29 16:15

    While the other answers are great, I wanted to achieve similar results while also illustrating the distribution with a scatter plot of the sample.

    More details can be found here: Python 3d plot of multivariate gaussian distribution

    The result looks like this:

    And is generated using the following code:

    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on Matplotlib < 3.2)
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from scipy.stats import multivariate_normal
    
    
    # Sample parameters
    mu = np.array([0, 0])
    sigma = np.array([[0.7, 0.2], [0.2, 0.3]])
    rv = multivariate_normal(mu, sigma)
    sample = rv.rvs(500)
    
    # Bounds parameters
    x_abs = 2.5
    y_abs = 2.5
    x_grid, y_grid = np.mgrid[-x_abs:x_abs:.02, -y_abs:y_abs:.02]
    
    pos = np.empty(x_grid.shape + (2,))
    pos[:, :, 0] = x_grid
    pos[:, :, 1] = y_grid
    
    levels = np.linspace(0, 1, 40)
    
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')  # fig.gca() no longer accepts projection= in recent Matplotlib
    
    # Removes the grey panes in 3d plots
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    
    # The heatmap
    ax.contourf(x_grid, y_grid, 0.1 * rv.pdf(pos),
                zdir='z', levels=0.1 * levels, alpha=0.9)
    
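    # The wireframe
    ax.plot_wireframe(x_grid, y_grid, rv.pdf(pos),
                      rstride=10, cstride=10, color='k', label='pdf')
    
    # The scatter. Note that the altitude is defined based on the pdf of the
    # random variable (lifted by 5% so the points sit just above the surface)
    ax.scatter(sample[:, 0], sample[:, 1], 1.05 * rv.pdf(sample), c='k',
               label='sample')
    
    # The legend entries come from the labels set on the artists above
    ax.legend()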
    ax.set_title("Gaussian sample and pdf")
    ax.set_xlim3d(-x_abs, x_abs)
    ax.set_ylim3d(-y_abs, y_abs)
    ax.set_zlim3d(0, 1)
    
    plt.show()
    
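One way to see what `rv.pdf(pos)` evaluates on the grid is to compare it with the closed-form bivariate normal density. The sketch below is an illustration, not part of the original answer: `bivariate_normal_pdf` is a hypothetical helper, the grid is built with `np.dstack` (equivalent to filling `pos` by hand), and the last lines check that with a diagonal covariance (independent variables) the joint pdf equals the product of two `scipy.stats.norm` marginals.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm


def bivariate_normal_pdf(pos, mu, sigma):
    """Hypothetical helper: closed-form bivariate normal density.

    `pos` is any array whose last axis holds (x, y) pairs.
    """
    diff = pos - mu                          # broadcast subtraction, shape (..., 2)
    inv = np.linalg.inv(sigma)               # Sigma^{-1}
    # Quadratic form (p - mu)^T Sigma^{-1} (p - mu) for every point at once
    quad = np.einsum('...i,ij,...j->...', diff, inv, diff)
    norm_const = 1.0 / (2.0 * np.pi * np.sqrt(np.linalg.det(sigma)))
    return norm_const * np.exp(-0.5 * quad)


# Same parameters and grid as the answer
mu = np.array([0.0, 0.0])
sigma = np.array([[0.7, 0.2], [0.2, 0.3]])
x_grid, y_grid = np.mgrid[-2.5:2.5:.02, -2.5:2.5:.02]
pos = np.dstack((x_grid, y_grid))            # shape (rows, cols, 2)

manual = bivariate_normal_pdf(pos, mu, sigma)
scipy_pdf = multivariate_normal(mu, sigma).pdf(pos)
print(np.allclose(manual, scipy_pdf))        # True

# Independent case: diagonal covariance, so the joint pdf is the
# product of the two one-dimensional marginal densities
sd_x, sd_y = np.sqrt(0.7), np.sqrt(0.3)
indep = multivariate_normal(mu, np.diag([sd_x**2, sd_y**2])).pdf(pos)
product = norm.pdf(x_grid, 0.0, sd_x) * norm.pdf(y_grid, 0.0, sd_y)
print(np.allclose(indep, product))           # True
```

The `einsum` call evaluates the quadratic form for every grid point without a Python loop, which is the same layout (last axis of size 2) that `multivariate_normal.pdf` expects.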
