Plot correlation matrix using pandas

渐次进展 2020-11-30 16:23

I have a data set with a huge number of features, so analysing the correlation matrix has become very difficult. I want to plot a correlation matrix which we get using d

12 Answers
  • 2020-11-30 16:38

    Try this function, which also displays variable names for the correlation matrix:

    import matplotlib.pyplot as plt

    def plot_corr(df, size=10):
        '''Plot a graphical correlation matrix for each pair of columns in the dataframe.

        Input:
            df: pandas DataFrame
            size: vertical and horizontal size of the plot, in inches'''

        corr = df.corr()
        fig, ax = plt.subplots(figsize=(size, size))
        ax.matshow(corr)
        # Label both axes with the column names of the correlation matrix
        plt.xticks(range(len(corr.columns)), corr.columns)
        plt.yticks(range(len(corr.columns)), corr.columns)
    
  • 2020-11-30 16:39

    You can use pyplot.matshow() from matplotlib:

    import matplotlib.pyplot as plt
    
    plt.matshow(dataframe.corr())
    plt.show()
    

    Edit:

    In the comments was a request for how to change the axis tick labels. Here's a deluxe version that is drawn on a bigger figure size, has axis labels to match the dataframe, and a colorbar legend to interpret the color scale.

    I'm including how to adjust the size and rotation of the labels, and I'm using a figure ratio that makes the colorbar and the main figure come out the same height.

    corr = df.corr()
    f = plt.figure(figsize=(19, 15))
    plt.matshow(corr, fignum=f.number)
    # Take the labels from corr so they always match the plotted matrix
    plt.xticks(range(corr.shape[1]), corr.columns, fontsize=14, rotation=45)
    plt.yticks(range(corr.shape[0]), corr.columns, fontsize=14)
    cb = plt.colorbar()
    cb.ax.tick_params(labelsize=14)
    plt.title('Correlation Matrix', fontsize=16)
    

  • 2020-11-30 16:40

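    You can observe the relations between features by drawing either a heatmap with seaborn or a scatter matrix with pandas.

    Scatter matrix:

    import pandas as pd

    # scatter_matrix lives in pandas.plotting; the old top-level pd.scatter_matrix was removed
    pd.plotting.scatter_matrix(dataframe, alpha=0.3, figsize=(14, 8), diagonal='kde')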
    

    If you also want to visualize each feature's skewness, use seaborn pairplots:

    import seaborn as sns

    sns.pairplot(dataframe)
    

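    Seaborn heatmap:

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    f, ax = plt.subplots(figsize=(10, 8))
    corr = dataframe.corr()
    # The all-False mask hides nothing; np.bool was removed from NumPy, so use the builtin bool
    sns.heatmap(corr, mask=np.zeros_like(corr, dtype=bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
                square=True, ax=ax)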
    

    The output is a correlation map of the features. In this example, the correlation between Grocery and Detergents is high. Similarly:

    Products with high correlation:
    1. Grocery and Detergents
    Products with medium correlation:
    1. Milk and Grocery
    2. Milk and Detergents_Paper
    Products with low correlation:
    1. Milk and Deli
    2. Frozen and Fresh
    3. Frozen and Deli

    From pairplots: You can observe the same set of relations from pairplots or a scatter matrix. These plots additionally let you judge whether the data is normally distributed.

    Note: The plots above are drawn from the same data that is used for the heatmap.
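
    With many features, reading the high, medium and low pairs off a plot by eye gets tedious, so a ranked list can help. The following is a minimal sketch, not part of the original answer: the helper name top_correlated_pairs and the synthetic columns a, b, c, d are made up for illustration.

```python
import numpy as np
import pandas as pd


def top_correlated_pairs(df, n=10):
    """Return the n column pairs with the largest absolute correlation.

    Hypothetical helper for this sketch, not a pandas function.
    """
    # Restrict to numeric columns so corr() behaves the same across pandas versions.
    corr = df.select_dtypes("number").corr()
    # Keep only the strict upper triangle: each pair appears once and the
    # diagonal (self-correlation) is excluded.
    upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    pairs = corr.where(upper).stack().dropna()
    # Rank by absolute value but keep the sign in the result.
    order = pairs.abs().sort_values(ascending=False).index
    return pairs.reindex(order).head(n)


# Synthetic stand-in data; replace with your own dataframe.
rng = np.random.default_rng(0)
base = rng.normal(size=200)
demo = pd.DataFrame({
    "a": base,
    "b": base + rng.normal(scale=0.1, size=200),   # strongly tied to "a"
    "c": rng.normal(size=200),                     # independent noise
    "d": -base + rng.normal(scale=1.0, size=200),  # negatively tied to "a"
})
top_pairs = top_correlated_pairs(demo, n=3)
print(top_pairs)
```

    The result is a Series indexed by (column, column) pairs, so it can be filtered further, for example by a threshold on the absolute value.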

  • 2020-11-30 16:41

    You can use the imshow() method from matplotlib:

    import pandas as pd
    import matplotlib.pyplot as plt
    plt.style.use('ggplot')
    
    plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
    plt.colorbar()
    tick_marks = list(range(len(X.columns)))
    plt.xticks(tick_marks, X.columns, rotation='vertical')
    plt.yticks(tick_marks, X.columns)
    plt.show()
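
    Because correlations always lie between -1 and 1, it can help to pin the color scale to that range and use a diverging colormap, so that negative and positive values are easy to tell apart. The variant below is only a sketch under assumptions: the "coolwarm" colormap is a swapped-in choice rather than the one used above, and X here is synthetic stand-in data with made-up column names.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for X; replace with your own dataframe.
rng = np.random.default_rng(1)
X = pd.DataFrame(rng.normal(size=(100, 4)), columns=["w", "x", "y", "z"])
X["z"] = -X["w"] + 0.5 * X["z"]  # force one clearly negative correlation

corr = X.corr()
fig, ax = plt.subplots(figsize=(6, 5))
# vmin/vmax fix the scale to the full correlation range, so 0 always maps
# to the midpoint of the diverging colormap.
im = ax.imshow(corr.to_numpy(), cmap="coolwarm", vmin=-1, vmax=1,
               interpolation="nearest")
fig.colorbar(im, ax=ax)
tick_marks = list(range(len(corr.columns)))
ax.set_xticks(tick_marks)
ax.set_xticklabels(corr.columns, rotation="vertical")
ax.set_yticks(tick_marks)
ax.set_yticklabels(corr.columns)
```

    Call plt.show() afterwards when running this as a script.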
    
  • 2020-11-30 16:44

    If your dataframe is df, you can simply use:

    import matplotlib.pyplot as plt
    import seaborn as sns
    
    plt.figure(figsize=(15, 10))
    sns.heatmap(df.corr(), annot=True)
    
  • 2020-11-30 16:48

    Seaborn's heatmap version:

    import seaborn as sns
    corr = dataframe.corr()
    sns.heatmap(corr, 
                xticklabels=corr.columns.values,
                yticklabels=corr.columns.values)
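
    A correlation matrix is symmetric, so with a large number of features roughly half of the cells are redundant. One possible refinement, sketched here and not part of the original answer, is to hide the diagonal and upper triangle through heatmap's mask argument. The dataframe below is synthetic and its column names are made up.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data; replace with your own dataframe.
rng = np.random.default_rng(2)
dataframe = pd.DataFrame(rng.normal(size=(150, 5)), columns=list("abcde"))

corr = dataframe.corr()
# True marks the cells to hide: the diagonal and everything above it.
mask = np.triu(np.ones(corr.shape, dtype=bool))

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(corr, mask=mask, cmap="coolwarm", vmin=-1, vmax=1,
            square=True, ax=ax)
```

    Only the lower triangle is drawn, which leaves each pair of features shown exactly once.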
    