Efficiently create a density plot for high-density regions, points for sparse regions

栀梦 2020-12-13 22:01

I need to make a plot that works like a density plot in the high-density regions but, below some density threshold, shows individual points instead. I couldn't find any existing …
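
A minimal sketch of the idea (not the asker's code), using a 2D histogram as the density estimate: bin the points with np.histogram2d, color only the bins that hold at least min_count points, and scatter the points from the remaining bins individually. The function names split_by_histogram_density and plot_histogram_density_scatter, and the min_count threshold, are hypothetical choices for this illustration.

```python
import numpy as np


def split_by_histogram_density(x, y, bins=50, min_count=10):
    """Bin (x, y) with a 2D histogram and flag points lying in sparse bins.

    min_count is a hypothetical threshold: bins holding fewer points than
    this are treated as sparse, and their points are drawn individually.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    counts, xedges, yedges = np.histogram2d(x, y, bins=bins)

    # np.digitize returns i such that edges[i-1] <= v < edges[i], so subtract 1
    # for a 0-based bin index, and clip so points on the right-most edge land
    # in the last bin, matching how np.histogram2d counts them.
    ix = np.clip(np.digitize(x, xedges) - 1, 0, counts.shape[0] - 1)
    iy = np.clip(np.digitize(y, yedges) - 1, 0, counts.shape[1] - 1)

    # True for points whose bin count is below the threshold
    sparse = counts[ix, iy] < min_count
    return counts, xedges, yedges, sparse


def plot_histogram_density_scatter(ax, x, y, bins=50, min_count=10, cmap="viridis"):
    """Color the dense bins and scatter the points from the sparse ones."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    counts, xedges, yedges, sparse = split_by_histogram_density(x, y, bins, min_count)

    # Mask sparse bins so they stay blank instead of being colored
    dense = np.ma.masked_less(counts, min_count)
    # counts is indexed [x_bin, y_bin]; pcolormesh expects [row=y, col=x]
    mesh = ax.pcolormesh(xedges, yedges, dense.T, cmap=cmap)
    ax.scatter(x[sparse], y[sparse], s=2, c="black")
    return mesh
```

Usage with pyplot: fig, ax = plt.subplots(); plot_histogram_density_scatter(ax, x, y); plt.show(). Binning with np.histogram2d is linear in the number of points, whereas evaluating gaussian_kde at every data point costs time proportional to the square of the sample size, so a histogram version may be preferable for very large datasets.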

4 answers
  •  醉话见心
    2020-12-13 22:39

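    For the record, here is the result of a new attempt using scipy.stats.gaussian_kde rather than a 2D histogram. Points whose estimated density falls below the threshold are drawn individually, while the KDE evaluated on a regular grid is drawn as an image with the threshold contour on top. Depending on the purpose, color meshes and contours could be combined in other ways.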

    import numpy as np
    from matplotlib import pyplot as plt
    from scipy.stats import gaussian_kde
    
    # parameters
    npts = 5000         # number of sample points
    bins = 100          # number of bins in density maps
    threshold = 0.01    # density threshold for scatter plot
    
    # initialize figure
    fig, ax = plt.subplots()
    
    # create a random dataset (two Gaussian clusters; the sample size must be an int)
    x1, y1 = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], npts // 2).T
    x2, y2 = np.random.multivariate_normal([4, 4], [[4, 0], [0, 1]], npts // 2).T
    x = np.hstack((x1, x2))
    y = np.hstack((y1, y2))
    points = np.vstack([x, y])
    
    # perform kernel density estimate
    kde = gaussian_kde(points)
    z = kde(points)
    
    # mask points above density threshold
    x = np.ma.masked_where(z > threshold, x)
    y = np.ma.masked_where(z > threshold, y)
    
    # plot unmasked points
    ax.scatter(x, y, c='black', marker='.')
    
    # get bounds from axes
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    
    # prepare grid for density map
    xedges = np.linspace(xmin, xmax, bins)
    yedges = np.linspace(ymin, ymax, bins)
    xx, yy = np.meshgrid(xedges, yedges)
    gridpoints = np.array([xx.ravel(), yy.ravel()])
    
    # compute density map
    zz = np.reshape(kde(gridpoints), xx.shape)
    
    # plot density map
    im = ax.imshow(zz, cmap='CMRmap_r', interpolation='nearest',
                   origin='lower', extent=[xmin, xmax, ymin, ymax])
    
    # plot threshold contour
    cs = ax.contour(xx, yy, zz, levels=[threshold], colors='black')
    
    # show
    fig.colorbar(im)
    plt.show()
    

    [Figure: Smooth scatter plot]
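
The threshold of 0.01 in the answer is an absolute density, so it has to be retuned whenever x or y are rescaled. A possible variation, not part of the original answer: choose the threshold as a quantile of the KDE values at the data points, and mask the density grid below it so that color appears only in the dense region (one of the color mesh and contour combinations mentioned above). The names kde_density_split, plot_kde_density_scatter and sparse_fraction are hypothetical, and filled contours are swapped in for imshow here.

```python
import numpy as np
from scipy.stats import gaussian_kde


def kde_density_split(x, y, sparse_fraction=0.1, gridsize=100):
    """Estimate density with a Gaussian KDE and split points by a quantile threshold.

    sparse_fraction (hypothetical parameter) is the share of points, ranked by
    their estimated density, that will be drawn individually.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    points = np.vstack([x, y])
    kde = gaussian_kde(points)

    # Density at each data point; using a quantile of these values means the
    # set of sparse points does not change when x or y are rescaled.
    z = kde(points)
    threshold = np.quantile(z, sparse_fraction)
    sparse = z < threshold

    # Evaluate the KDE on a regular grid spanning the data
    xx, yy = np.meshgrid(np.linspace(x.min(), x.max(), gridsize),
                         np.linspace(y.min(), y.max(), gridsize))
    zz = kde(np.vstack([xx.ravel(), yy.ravel()])).reshape(xx.shape)

    # Mask grid cells below the threshold so only dense regions get colored
    return xx, yy, np.ma.masked_less(zz, threshold), sparse, threshold


def plot_kde_density_scatter(ax, x, y, sparse_fraction=0.1, gridsize=100, cmap="CMRmap_r"):
    """Filled contours above the threshold, individual points below it."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xx, yy, zz, sparse, threshold = kde_density_split(x, y, sparse_fraction, gridsize)

    # Masked (below-threshold) cells are left unfilled
    filled = ax.contourf(xx, yy, zz, levels=10, cmap=cmap)
    # Outline the threshold; masked cells are filled with 0, which is below it
    ax.contour(xx, yy, zz.filled(0.0), levels=[threshold], colors="black")
    ax.scatter(x[sparse], y[sparse], s=2, c="black")
    return filled
```

With the answer's dataset, plot_kde_density_scatter(ax, x, y) on a fresh axes could stand in for the scatter, imshow and contour calls, with sparse_fraction taking the place of threshold = 0.01.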
