How to plot correlation heatmap when using pyspark+databricks

寵の児 posted on 2020-07-06 20:22:10

Question


I am studying pyspark in databricks. I want to generate a correlation heatmap. Let's say this is my data:

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])

And this is my code:

import pyspark
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from ggplot import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.mllib.stat import Statistics

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
matrix.collect()[0]["pearson({})".format(vector_col)].values

Up to this point everything works, and I can get the values of the correlation matrix.

Now my problems are:

  1. How do I convert the matrix to a DataFrame? I have tried the approaches from "How to convert DenseMatrix to spark DataFrame in pyspark?" and "How to get correlation matrix values pyspark", but they do not work for me.
  2. How do I generate a correlation heatmap from it?

I have only just started learning pyspark and databricks, so either ggplot or matplotlib is fine for my problem.


Answer 1:


I think the point where you get confused is:

matrix.collect()[0]["pearson({})".format(vector_col)].values

Calling .values on a DenseMatrix gives you a flat array of all the values, but what you are actually looking for is a list of lists representing the correlation matrix.
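To make the difference concrete, here is a small sketch that needs only numpy, not Spark. It assumes the flat values are laid out column by column, which is how a non-transposed DenseMatrix stores them; the numbers are the correlations this data produces (they match the output printed further down). The names flat_values, n and corr are just for illustration.

```python
import numpy as np

# Flat values as a 3x3 DenseMatrix would hold them (assumed column-major layout).
flat_values = np.array([1.0, 0.9582184104641529, 0.9780872729407004,
                        0.9582184104641529, 1.0, 0.8776695567739841,
                        0.9780872729407004, 0.8776695567739841, 1.0])

n = 3  # number of input columns given to the VectorAssembler

# order="F" fills the result column by column, mirroring the assumed storage.
# A correlation matrix is symmetric, so row-major order would give the same result here.
corr = flat_values.reshape(n, n, order="F")

print(flat_values.shape)  # one-dimensional: no row/column structure
print(corr.shape)         # two-dimensional: one row and one column per feature
print(corr.tolist())      # a list of lists, ready for createDataFrame or plotting
```

This is roughly the reshaping that .toArray() saves you from doing by hand.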

import matplotlib.pyplot as plt
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

columns = ['col1','col2','col3']

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              columns)
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)

So far this is basically your code. Instead of calling .values, use .toArray().tolist() to get a list of lists representing the correlation matrix:

matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
corrmatrix = matrix.toArray().tolist()
print(corrmatrix)

Output:

[[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]

The advantage of this approach is that you can turn a list of lists easily into a dataframe:

df = spark.createDataFrame(corrmatrix,columns)
df.show()

Output:

+------------------+------------------+------------------+
|              col1|              col2|              col3|
+------------------+------------------+------------------+
|               1.0|0.9582184104641529|0.9780872729407004|
|0.9582184104641529|               1.0|0.8776695567739841|
|0.9780872729407004|0.8776695567739841|               1.0|
+------------------+------------------+------------------+
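Note that the Spark DataFrame has no row labels, so which row belongs to which column is only implied by the order. If the matrix is only needed for inspection or plotting on the driver, a pandas DataFrame with labels on both axes may be more convenient. A minimal sketch, assuming pandas is available (the name corr_pdf is made up for this example, and the values are copied from the output above):

```python
import pandas as pd

columns = ['col1', 'col2', 'col3']

# Correlation values copied from the printed corrmatrix above.
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]

# Use the same names for rows and columns so every cell is addressable by label.
corr_pdf = pd.DataFrame(corrmatrix, index=columns, columns=columns)

print(corr_pdf.round(3))
print(corr_pdf.loc['col1', 'col2'])  # correlation between col1 and col2
```

The correlation matrix has only one row per input column, so holding it in pandas on the driver should not be a problem even when the original data is large.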

To answer your second question, here is just one of the many ways to plot a heatmap (there are other approaches, and it looks even better with seaborn).

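def plot_corr_matrix(correlations, attr, fig_no):
    fig = plt.figure(fig_no)
    ax = fig.add_subplot(111)
    ax.set_title("Correlation Matrix for Specified Attributes")
    # Draw the matrix first, with the colour scale fixed to the range [-1, 1]
    cax = ax.matshow(correlations, vmax=1, vmin=-1)
    # Put exactly one tick on each row/column so the labels line up with the cells
    ax.set_xticks(range(len(attr)))
    ax.set_yticks(range(len(attr)))
    ax.set_xticklabels(attr)
    ax.set_yticklabels(attr)
    fig.colorbar(cax)
    plt.show()

plot_corr_matrix(corrmatrix, columns, 234)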
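Since seaborn was mentioned, here is a rough sketch of what that variant could look like. It assumes seaborn and pandas are installed on the cluster; the function name plot_corr_heatmap and the "coolwarm" colour map are my own choices, not anything required by the library.

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_corr_heatmap(corr, labels):
    """Draw an annotated correlation heatmap and return the matplotlib Figure."""
    # Label both axes so seaborn can pick up the tick labels from the DataFrame.
    corr_df = pd.DataFrame(corr, index=labels, columns=labels)
    fig, ax = plt.subplots(figsize=(5, 4))
    # annot=True writes each value into its cell; vmin/vmax fix the scale to [-1, 1].
    sns.heatmap(corr_df, annot=True, fmt=".2f", vmin=-1, vmax=1,
                cmap="coolwarm", square=True, ax=ax)
    ax.set_title("Correlation Matrix for Specified Attributes")
    return fig

# Same values as the corrmatrix printed earlier in this answer.
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]
fig = plot_corr_heatmap(corrmatrix, ['col1', 'col2', 'col3'])
```

The function returns the figure instead of calling plt.show(), so you can decide how to render it. In a notebook, ending the cell with fig or calling plt.show() usually works; on some Databricks runtimes you may need display(fig) instead.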


Source: https://stackoverflow.com/questions/55546467/how-to-plot-correlation-heatmap-when-using-pysparkdatabricks
