How to plot multi-color line if x-axis is date time index of pandas

闹比i 2020-11-28 11:05

I am trying to plot a multi-color line using a pandas Series. I know that matplotlib.collections.LineCollection greatly improves efficiency. But LineCollection

2 answers
  •  隐瞒了意图╮
    2020-11-28 11:26

    ImportanceOfBeingErnest's answer is a very good one and saved me many hours of work. I want to share how I adapted it to change the line color based on a signal column in a pandas DataFrame.

    import matplotlib.dates as mdates
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.collections import LineCollection
    from matplotlib.colors import ListedColormap, BoundaryNorm
    

    Make a test DataFrame:

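    equity = pd.DataFrame(index=pd.date_range('20150701', periods=150))
    equity['price'] = np.random.uniform(low=15500, high=18500, size=(150,))
    equity['signal'] = 0
    # Set the signal by position with .iloc; chained assignment such as
    # equity.signal[15:45] = 1 does not update the DataFrame under copy-on-write
    sig = equity.columns.get_loc('signal')
    equity.iloc[15:45, sig] = 1
    equity.iloc[60:90, sig] = -1
    equity.iloc[105:135, sig] = 1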
    
    # Create a colormap for crimson, limegreen and gray and a norm to color
    # signal = -1 crimson, signal = 1 limegreen, and signal = 0 lightgray
    cmap = ListedColormap(['crimson', 'lightgray', 'limegreen'])
    norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)
    
    # Convert dates to numbers
    inxval = mdates.date2num(equity.index.to_pydatetime())
    
    # Create a set of line segments so that we can color them individually.
    # This creates the points as an N x 1 x 2 array so that we can stack points
    # together easily to get the segments. The segments array for LineCollection
    # needs to be (number of lines) x (points per line) x 2 (x and y).
    points = np.array([inxval, equity.price.values]).T.reshape(-1,1,2)
    segments = np.concatenate([points[:-1],points[1:]], axis=1)
    
    # Create the line collection object, setting the colormapping parameters.
    # Have to set the actual values used for colormapping separately.
    lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=2)
    
    # Set color using signal values. N points give N - 1 segments, so color
    # each segment by the signal at its starting point
    lc.set_array(equity['signal'].values[:-1])
    
    fig, ax = plt.subplots()
    fig.autofmt_xdate()
    
    # Add collection to axes
    ax.add_collection(lc)
    
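    # Set the limits explicitly; passing Timestamps also makes Matplotlib
    # treat the x-axis as dates and label the ticks accordingly
    ax.set_xlim(equity.index.min(), equity.index.max())
    ax.set_ylim(equity.price.min(), equity.price.max())
    plt.tight_layout()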
    
    # plt.savefig('test_mline.png', dpi=150)
    plt.show()
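The same steps can be wrapped into a small reusable function. This is a minimal sketch, assuming a Series with a DatetimeIndex and one color value per point; `datetime_line_collection` is a hypothetical helper name, not part of matplotlib or pandas, and the example prices are deterministic stand-ins for the random ones above. Like the answer, it colors each segment by the value at its starting point.

```python
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm, ListedColormap


def datetime_line_collection(series, color_values, cmap, norm, **kwargs):
    """Build a LineCollection from a Series indexed by a DatetimeIndex.

    Hypothetical helper, not part of matplotlib or pandas. Segment i joins
    point i to point i + 1 and is colored by color_values[i].
    """
    # LineCollection needs numeric x values, so convert the dates to
    # Matplotlib's float date representation.
    x = mdates.date2num(series.index.to_pydatetime())
    y = series.to_numpy(dtype=float)
    # N points as an (N, 1, 2) array -> N - 1 segments as an (N - 1, 2, 2) array
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    lc = LineCollection(segments, cmap=cmap, norm=norm, **kwargs)
    # One value per segment: drop the value belonging to the last point
    lc.set_array(np.asarray(color_values)[:-1])
    return lc


# Deterministic example data (hypothetical; the answer uses random prices)
equity = pd.DataFrame(
    {"price": np.linspace(15500.0, 18500.0, 150)},
    index=pd.date_range("20150701", periods=150),
)
signal = np.zeros(150, dtype=int)
signal[15:45] = 1
signal[60:90] = -1
signal[105:135] = 1

cmap = ListedColormap(["crimson", "lightgray", "limegreen"])
norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)
lc = datetime_line_collection(equity["price"], signal, cmap, norm, linewidth=2)
```

To draw it, add it with `ax.add_collection(lc)` and set the limits as in the answer, or call `ax.xaxis_date()` followed by `ax.autoscale_view()`. Passing one value per point and slicing inside the helper keeps the caller from having to remember the N versus N - 1 mismatch.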
    
