Plotly: How to make a figure with multiple lines and shaded area for standard deviations?

后端 未结 2 1713
梦毁少年i
梦毁少年i 2020-12-19 16:46

How can I use Plotly to produce a line plot with a shaded standard deviation? I am trying to achieve something similar to seaborn.tsplot. Any help is appreciated.

相关标签:
2条回答
  • 2020-12-19 16:49

    The following approach is fully flexible with regard to the number of columns in a pandas DataFrame and uses Plotly's default color cycle. If the number of lines exceeds the number of colors, the colors are reused from the start. `px.colors.qualitative.Plotly` can be replaced with any hex color sequence available in `px.colors.qualitative`, for example:

    Alphabet = ['#AA0DFE', '#3283FE', '#85660D', '#782AB6', '#565656', '#1...
    Alphabet_r = ['#FA0087', '#FBE426', '#B00068', '#FC1CBF', '#C075A6', '...
    [...]
    

    Complete code:

    # imports
    import plotly.graph_objs as go
    import plotly.express as px
    import pandas as pd
    import numpy as np
    
    # Sample data: three random walks (cumulative sums of uniform steps),
    # reproducible because the legacy global NumPy seed is fixed first.
    np.random.seed(1)
    walk_bounds = {'A': (-1, 2), 'B': (-4, 3), 'C': (-1, 3)}
    # dict insertion order keeps the draws in the order A, B, C
    steps = {
        name: np.random.uniform(low=lo, high=hi, size=25).tolist()
        for name, (lo, hi) in walk_bounds.items()
    }
    df = pd.DataFrame(steps).cumsum()
    
    # define colors as a list of hex strings: Plotly's default qualitative
    # sequence (any hex sequence from px.colors.qualitative can be used instead)
    colors = px.colors.qualitative.Plotly
    
    # convert plotly hex colors to rgba to enable transparency adjustments
    def hex_rgba(hex, transparency):
        """Convert a hex color string to an ``(r, g, b, alpha)`` tuple.

        Args:
            hex: Color such as ``'#AA0DFE'`` or the CSS shorthand ``'#A0F'``;
                the leading ``'#'`` is optional. (The name shadows the builtin
                ``hex`` but is kept so keyword callers keep working.)
            transparency: Alpha value, placed unchanged in the fourth slot.

        Returns:
            A tuple ``(r, g, b, transparency)`` with integer channels 0-255.

        Raises:
            ValueError: if the digits are not hexadecimal or there are not
                exactly 3 or 6 of them.
        """
        col_hex = hex.lstrip('#')
        if len(col_hex) == 3:
            # CSS shorthand: '#abc' means '#aabbcc'
            col_hex = ''.join(ch * 2 for ch in col_hex)
        if len(col_hex) != 6:
            # without this check a wrong-length string would be sliced
            # silently into meaningless channel values
            raise ValueError(f"expected a 3- or 6-digit hex color, got {hex!r}")
        r, g, b = (int(col_hex[i:i + 2], 16) for i in (0, 2, 4))
        return (r, g, b, transparency)
    
    # Translucent band colors: each hex color becomes an (r, g, b, 0.2) tuple,
    # and the tuple's text form "(r, g, b, a)" is exactly what 'rgba' needs.
    rgba = list(map(lambda hex_col: hex_rgba(hex_col, transparency=0.2), colors))
    colCycle = [f"rgba{rgb_tuple}" for rgb_tuple in rgba]
    
    # Make sure the colors run in cycles if there are more lines than colors
    def next_col(cols):
        """Yield the items of *cols* endlessly, restarting from the first one.

        Lets series beyond the palette length reuse colors from the start.
        itertools.cycle also copes with one-shot iterators, and it simply stops
        on an empty *cols*. A hand-written ``while True: for ...`` loop would
        spin forever without yielding in both of those cases.
        """
        import itertools  # local import keeps this snippet self-contained

        yield from itertools.cycle(cols)
    import itertools

    # translucent rgba colors for the shaded bands, cycled past the palette end
    line_color = next_col(cols=colCycle)

    # plotly figure
    fig = go.Figure()

    # x positions are the same for every series (1-based row numbers), so
    # compute them once instead of once per column
    x = list(df.index.values + 1)
    # the band outline runs left-to-right along the upper edge, then back
    # right-to-left along the lower edge
    x_band = x + x[::-1]

    # opaque palette colors for the lines, cycled in step with the band colors
    # (both sequences are derived from `colors` in the same order)
    line_hex = itertools.cycle(colors)

    # add a +/- 1 standard deviation band and a line for each series
    for col in df:
        band_col = next(line_color)
        y1 = df[col]
        # np.std (ddof=0) over the whole column, computed once per series
        # rather than once per point inside the comprehensions
        sd = np.std(y1)
        y1_upper = (y1 + sd).tolist()
        y1_lower = (y1 - sd).tolist()[::-1]

        # standard deviation area: 'toself' closes the polygon on itself, so
        # nothing is filled back to x=0 (which 'tozerox' does)
        fig.add_traces(go.Scatter(x=x_band,
                                  y=y1_upper + y1_lower,
                                  fill='toself',
                                  fillcolor=band_col,
                                  line=dict(color='rgba(255,255,255,0)'),
                                  hoverinfo='skip',  # polygon vertices are not meaningful hover data
                                  legendgroup=str(col),
                                  showlegend=False,
                                  name=col))

        # line trace; drawn opaque so it stands out from its translucent band,
        # and sharing legendgroup makes its legend entry toggle the band too
        fig.add_traces(go.Scatter(x=x,
                                  y=y1,
                                  line=dict(color=next(line_hex), width=2.5),
                                  mode='lines',
                                  legendgroup=str(col),
                                  name=col))

    # set x-axis to span exactly the data
    fig.update_layout(xaxis=dict(range=[1, len(df)]))

    fig.show()
    
    0 讨论(0)
  • 2020-12-19 17:06

    I was able to come up with something similar. I'm posting the code here in case it is useful to someone else, and I welcome any suggestions for improvement.

    import random

    import matplotlib.colors
    import numpy as np
    import plotly.graph_objects as go

    # random color generation in plotly: matplotlib's named-color table
    # (name -> hex string), plus an RGB view and a flat list of the hex values
    # to draw a random color from
    hex_colors_dic = dict(matplotlib.colors.cnames)
    rgb_colors_dic = {name: matplotlib.colors.to_rgb(code)
                      for name, code in hex_colors_dic.items()}
    hex_colors_only = list(hex_colors_dic.values())
    
    data = [[1, 3, 5, 4],
            [2, 3, 5, 4],
            [1, 1, 4, 5],
            [2, 3, 5, 4]]
    # Column-wise statistics across the rows: one mean and one (population,
    # ddof=0) standard deviation per x position.
    samples = np.asarray(data)
    mean, std = samples.mean(axis=0), samples.std(axis=0)
    
    # draw figure: upper bound, mean and lower bound share one random color;
    # each 'tonexty' fill shades the gap to the trace added just before it
    fig = go.Figure()
    c = random.choice(hex_colors_only)
    # one x position per column of data, rather than a hard-coded np.arange(4)
    x = np.arange(len(mean))
    fig.add_trace(go.Scatter(x=x, y=mean + std,
                             mode='lines',
                             line=dict(color=c, width=0.1),
                             name='upper bound'))
    fig.add_trace(go.Scatter(x=x, y=mean,
                             mode='lines',
                             line=dict(color=c),
                             fill='tonexty',
                             name='mean'))
    fig.add_trace(go.Scatter(x=x, y=mean - std,
                             mode='lines',
                             line=dict(color=c, width=0.1),
                             fill='tonexty',
                             name='lower bound'))
    fig.show()
    
    0 讨论(0)
提交回复
热议问题