Can I plot a linear regression with datetimes on the x-axis with seaborn?

旧巷少年郎 2020-12-29 14:26

My DataFrame looks like this:

            amount
date    
2014-01-06  1
2014-01-07  1
2014-01-08  4
2014-01-09  1
2014-01-14  1

I would li

2 answers
  •  再見小時候
    2020-12-29 14:41

    Since Seaborn's regplot can't fit a regression line against datetime values directly, I'm going to use a workaround. First, I'll put the dates in the DataFrame's index:

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    # Make dataframe
    df = pd.DataFrame({'amount' : [1,
                                   1,
                                   4,
                                   1,
                                   1]},
                      index = ['2014-01-06',
                               '2014-01-07',
                               '2014-01-08',
                               '2014-01-09',
                               '2014-01-14'])
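    If your data instead arrives with the dates in an ordinary column, `set_index` moves that column into the index. A minimal sketch, assuming the column is named `date` (the name shown in the question's printout); the `raw` frame is hypothetical:

```python
import pandas as pd

# Hypothetical raw data with the dates in an ordinary column named 'date'
raw = pd.DataFrame({'date': ['2014-01-06', '2014-01-07', '2014-01-08',
                             '2014-01-09', '2014-01-14'],
                    'amount': [1, 1, 4, 1, 1]})

# Move the 'date' column into the index (the index keeps the name 'date')
df = raw.set_index('date')
print(df)
```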
    

    Second, convert the index to pd.DatetimeIndex:

    # Make index pd.DatetimeIndex
    df.index = pd.DatetimeIndex(df.index)
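    Steps one and two can also be collapsed into one by parsing the strings while building the frame. A sketch using `pd.to_datetime`, which returns a `DatetimeIndex` when given a plain list of date strings:

```python
import pandas as pd

dates = ['2014-01-06', '2014-01-07', '2014-01-08', '2014-01-09', '2014-01-14']

# pd.to_datetime on a list of strings returns a DatetimeIndex directly
df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]}, index=pd.to_datetime(dates))

print(type(df.index).__name__)  # DatetimeIndex
```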
    

    Next, build a new index that contains every calendar day from the first date to the last:

    # Build a daily index spanning the first to the last date
    idx = pd.date_range(df.index.min(), df.index.max())
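    With its default daily frequency, `pd.date_range` lists every calendar day between the two endpoints, inclusive. Comparing it with the original index shows which days have no data; a short sketch:

```python
import pandas as pd

df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]},
                  index=pd.to_datetime(['2014-01-06', '2014-01-07', '2014-01-08',
                                        '2014-01-09', '2014-01-14']))

# Every day from the first to the last date (freq='D' is the default)
idx = pd.date_range(df.index.min(), df.index.max())

# Days present in idx but absent from the data
missing = idx.difference(df.index)

print(len(idx))                            # 9
print(list(missing.strftime('%Y-%m-%d')))  # ['2014-01-10', '2014-01-11', '2014-01-12', '2014-01-13']
```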
    

    Third, reindex with the new index (idx):

    # Replace original index with idx
    df = df.reindex(index = idx)
    

    This produces a DataFrame with one row per day, where the days without data (2014-01-10 through 2014-01-13) have NaN in the amount column.
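    For a DatetimeIndex with unique dates, pandas also offers `DataFrame.asfreq('D')`, which should give the same daily frame as the `date_range` + `reindex` pair. A sketch that checks this and counts the empty days:

```python
import pandas as pd

df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]},
                  index=pd.to_datetime(['2014-01-06', '2014-01-07', '2014-01-08',
                                        '2014-01-09', '2014-01-14']))

# Reindex against a full daily range, as in the steps above
idx = pd.date_range(df.index.min(), df.index.max())
via_reindex = df.reindex(index=idx)

# Shortcut: conform the data to a daily frequency without aggregating
via_asfreq = df.asfreq('D')

print(via_reindex.equals(via_asfreq))      # True
print(via_reindex['amount'].isna().sum())  # 4
```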

    Fourth, because regplot needs a numeric x variable, I'll create a row count column to use as the x-axis:

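    # Insert row count as the last column: each date appears once, so the
    # cumulative sum of the per-date counts numbers the rows 1, 2, ..., len(df)
    df.insert(df.shape[1],
              'row_count',
              df.index.value_counts().sort_index().cumsum())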
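    Because every date appears exactly once after the reindex, the `value_counts().sort_index().cumsum()` expression simply numbers the rows 1 through N. And because the reindex filled in every calendar day, that number also equals the days elapsed since the first date plus one, so gaps in time keep their true spacing on the x-axis. A sketch checking both claims, with `np.arange` as a plainer equivalent:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]},
                  index=pd.to_datetime(['2014-01-06', '2014-01-07', '2014-01-08',
                                        '2014-01-09', '2014-01-14']))
df = df.reindex(index=pd.date_range(df.index.min(), df.index.max()))

# The expression used in the answer
row_count = df.index.value_counts().sort_index().cumsum()

# A plainer equivalent: 1, 2, ..., len(df)
simple = np.arange(1, len(df) + 1)

# Days elapsed since the first date, plus one
days = (df.index - df.index[0]).days + 1

print(np.array_equal(row_count.to_numpy(), simple))  # True
print(np.array_equal(simple, days.to_numpy()))       # True
```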
    

    Fifth, plot the regression line with 'row_count' as the x variable and 'amount' as the y variable. regplot skips the NaN rows when fitting, because its dropna parameter defaults to True:

    # Plot regression using Seaborn; regplot returns the matplotlib Axes it drew on
    ax = sns.regplot(data=df, x='row_count', y='amount')
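    With its default `order=1`, regplot draws an ordinary least-squares line of amount on row_count over the non-missing rows, but it only returns the Axes, not the coefficients. If you want the numbers, `numpy.polyfit` with degree 1 computes the same kind of fit. A sketch, using `np.arange` for the row counter (equivalent to the column built above):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]},
                  index=pd.to_datetime(['2014-01-06', '2014-01-07', '2014-01-08',
                                        '2014-01-09', '2014-01-14']))
df = df.reindex(index=pd.date_range(df.index.min(), df.index.max()))
df['row_count'] = np.arange(1, len(df) + 1)

# Fit only the days that actually have data, mirroring regplot's dropna=True
known = df.dropna(subset=['amount'])
slope, intercept = np.polyfit(known['row_count'], known['amount'], deg=1)

print(round(slope, 4), round(intercept, 4))  # -0.0619 1.8351
```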
    

    Sixth, if you would like dates along the x-axis instead of the row_count values, you can label the x-ticks with the dates from the index:

    # Put one tick at each row_count value and label it with the matching date.
    # Setting the tick positions explicitly avoids depending on how many ticks
    # matplotlib happens to choose automatically. Rotate the labels so you can read them.
    plt.xticks(df['row_count'], df.index.strftime('%Y-%m-%d'), rotation=45)

    # Change x-axis title
    plt.xlabel('date')

    plt.show()
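    Putting the six steps together, a self-contained version of the whole workaround might look like the sketch below. It uses the Axes methods instead of the pyplot functions (both label the same plot), assumes the non-interactive Agg backend, and saves to a hypothetical file `regplot_dates.png` instead of calling `plt.show()`, so it also runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # assumption: non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Steps 1-2: data with a DatetimeIndex
df = pd.DataFrame({'amount': [1, 1, 4, 1, 1]},
                  index=pd.to_datetime(['2014-01-06', '2014-01-07', '2014-01-08',
                                        '2014-01-09', '2014-01-14']))

# Step 3: one row per calendar day; missing days become NaN
df = df.reindex(index=pd.date_range(df.index.min(), df.index.max()))

# Step 4: numeric x variable (1, 2, ..., len(df))
df['row_count'] = np.arange(1, len(df) + 1)

# Step 5: regression of amount on row_count (NaN rows are dropped)
ax = sns.regplot(data=df, x='row_count', y='amount')

# Step 6: one tick per day, labelled with its date
ax.set_xticks(df['row_count'])
ax.set_xticklabels(df.index.strftime('%Y-%m-%d'), rotation=45)
ax.set_xlabel('date')

plt.tight_layout()
plt.savefig('regplot_dates.png')  # hypothetical output file name
```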
    

    Hope this helps!
