Transform entire axes (or a scatter plot) in matplotlib

前端 未结 1 1141
小鲜肉
小鲜肉 2020-12-21 02:58

I am plotting changes in the mean and variance of some data with the following code:

import matplotlib.pyplot as pyplot
import numpy

vis_mv(data, ax = None):
            


        
相关标签:
1条回答
  • 2020-12-21 03:35

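    Unfortunately, in older matplotlib versions `PathCollection` does not have a `.set_offset_transform()` method, but one can access the private `_transOffset` attribute and assign the rotating transformation to it. Matplotlib 3.6 and later provide a public `Collection.set_offset_transform()`, which should be preferred there: the private attribute is an implementation detail whose name has changed.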

    import matplotlib.pyplot as plt
    from matplotlib.transforms import Affine2D
    from matplotlib.collections import PathCollection
    import numpy as np; np.random.seed(3)
    
    def vis_mv(data, ax=None):
        """Draw the mean and variance of each series in *data* onto *ax*.

        Every element of *data* is one series, placed at x = 1, 2, ..., len(data).
        The variance is shown as a colour-gradient band covering [-var, +var] at
        each x; the area between that band and +/-vlim is painted white. The
        means are drawn as a white line overlaid with markers whose colours are
        sampled evenly from the current colormap.

        Parameters
        ----------
        data : sequence of array-like
            The series to summarise; ``np.mean`` / ``np.var`` are applied to
            each element.
        ax : Axes, optional
            Target axes; when omitted, ``plt.gca()`` is used.

        Returns
        -------
        Axes
            The axes that was drawn on.
        """
        ax = plt.gca() if ax is None else ax
        cmap = plt.get_cmap()

        n_series = len(data)
        positions = np.arange(1, n_series + 1)
        marker_colors = cmap(np.linspace(0, 1, n_series))

        means = np.array(list(map(np.mean, data)))
        variances = np.array(list(map(np.var, data)))
        # Half-height of the band: the largest variance, but never less than 1.
        half_height = max(1, variances.max())

        # Variance band: a left-to-right gradient image filling the whole band,
        # then white fills masking everything between +/-variance and the edges.
        gradient = [[0., 1.], [0., 1.]]
        ax.imshow(gradient, cmap=cmap, interpolation='bicubic',
                  extent=(1, n_series, -half_height, half_height), aspect='auto')
        for lower, upper in ((-half_height, -variances), (variances, half_height)):
            ax.fill_between(positions, lower, upper, color='white')

        # Means: line underneath (zorder 1), coloured markers on top (zorder 2).
        ax.plot(positions, means, color='white', zorder=1)
        ax.scatter(positions, means, color=marker_colors,
                   edgecolor='white', zorder=2)

        return ax
    
    data = np.random.normal(size=(9, 9))
    ax = vis_mv(data)

    # Rotate every artist by 90 degrees counter-clockwise in data space:
    # (x, y) -> (-y, x). Prepending the rotation to each artist's existing
    # transform applies it before the data -> display mapping.
    r = Affine2D().rotate_deg(90)

    for artist in [*ax.images, *ax.lines, *ax.collections]:
        artist.set_transform(r + artist.get_transform())
        if isinstance(artist, PathCollection):
            # Scatter markers are positioned through a separate offset
            # transform, which has to be rotated as well.
            rotated_offsets = r + artist.get_offset_transform()
            if hasattr(artist, "set_offset_transform"):
                # Public setter (matplotlib >= 3.6). Writing the old private
                # attribute there would silently have no effect.
                artist.set_offset_transform(rotated_offsets)
            else:
                # Older matplotlib: only the private attribute is available.
                artist._transOffset = rotated_offsets

    # Changing transforms does not update the view limits, so set them
    # explicitly. Under (x, y) -> (-y, x) the new x-range is the negated,
    # reversed old y-range and the new y-range is the old x-range.
    xmin, xmax, ymin, ymax = ax.axis()
    ax.axis((-ymax, -ymin, xmin, xmax))

    plt.show()
    

    0 讨论(0)
提交回复
热议问题