Adding y=x to a matplotlib scatter plot if I haven't kept track of all the data points that went in

迷失自我 2020-12-14 08:17

Here's some code that draws a scatter plot of several different series using matplotlib and then adds the line y=x:

import numpy as np, matplotlib.pyplot a         

3 Answers
  •  伪装坚强ぢ
    2020-12-14 08:19

    You don't need to keep track of the data yourself. Your matplotlib Axes object already knows its extent: after plotting, ax.get_xlim() and ax.get_ylim() report the axis limits, which with the default autoscaling cover all the plotted data.

    See below:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # random data 
    N = 37
    x = np.random.normal(loc=3.5, scale=1.25, size=N)
    y = np.random.normal(loc=3.4, scale=1.5, size=N)
    c = x**2 + y**2
    
    # now sort it just to make it look like it's related
    x.sort()
    y.sort()
    
    fig, ax = plt.subplots()
    ax.scatter(x, y, s=25, c=c, cmap=plt.cm.coolwarm, zorder=10)
    

    Here's the good part:

    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
        np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
    ]
    
    # now plot both limits against each other
    ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
    ax.set_aspect('equal')
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    fig.savefig('so.png', dpi=300)
    

    Et voilà
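
    If you need this in several places, the limits trick can be wrapped in a small helper. The sketch below is a minimal version under a few assumptions: the function name add_identity_line is hypothetical, and three series are plotted in a loop without storing the arrays, mirroring the situation in the question. Only ax.get_xlim() and ax.get_ylim() are consulted, exactly as in the answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def add_identity_line(ax, **line_kwargs):
    """Draw y = x across the current limits of ``ax``.

    Hypothetical helper wrapping the answer's technique: it reads the
    Axes limits and never looks at the plotted arrays.
    """
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    # Smallest and largest of all four limits, as np.min/np.max do in the answer.
    lims = [min(*xlim, *ylim), max(*xlim, *ylim)]
    (line,) = ax.plot(lims, lims, **line_kwargs)
    ax.set_aspect("equal")
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    return lims, line


rng = np.random.default_rng(0)
fig, ax = plt.subplots()

# Plot several series without keeping references to the data.
for shift in (0.0, 2.0, 5.0):
    ax.scatter(rng.normal(shift, 1.0, 20), rng.normal(shift, 1.5, 20), zorder=10)

lims, identity = add_identity_line(ax, color="k", alpha=0.75, zorder=0)
```

    Call the helper after the last scatter(): it fixes the axis limits, so anything plotted afterwards is not taken into account and may fall outside the view.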

    enter image description here

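    On Matplotlib 3.3 or newer, Axes.axline offers a different approach: it draws an unbounded line through a point with a given slope, so the endpoints never have to be computed. A minimal sketch, assuming that version. Note that axline also adds its anchor point (0, 0) to the data limits, so the sketch saves the limits produced by the scatter calls and restores them afterwards; the data and variable names are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
fig, ax = plt.subplots()
for shift in (1.0, 3.0, 6.0):
    # Arrays are created inline and never stored.
    ax.scatter(rng.normal(shift, 1.0, 20), rng.normal(shift, 1.5, 20), zorder=10)

# Record the autoscaled limits before axline() registers (0, 0) as data.
xlim, ylim = ax.get_xlim(), ax.get_ylim()

# Requires Matplotlib >= 3.3: an infinite line through (0, 0) with slope 1.
diag = ax.axline((0, 0), slope=1, color="k", alpha=0.75, zorder=0)

# Put the recorded limits back so the anchor point does not widen the view.
ax.set_xlim(xlim)
ax.set_ylim(ylim)
fig.canvas.draw()  # render once; the line is clipped to the restored limits
```

    Because the extent of an axline is worked out at draw time, the line keeps spanning the axes even if the limits change later (for example when zooming out), which the fixed-endpoint line from the answer does not.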