pyspark extract ROC curve?

人盡茶涼 submitted on 2019-12-14 03:56:34

Question


Is there a way to get the points on an ROC curve from Spark ML in PySpark? In the documentation I see an example for Scala but not Python: https://spark.apache.org/docs/2.1.0/mllib-evaluation-metrics.html

Is that right? I can certainly think of ways to implement it but I have to imagine it’s faster if there’s a pre-built function. I’m working with 3 million scores and a few dozen models so speed matters.

Thanks!


Answer 1:


Since the ROC curve is a plot of TPR against FPR, you can extract the needed values as follows:

roc = your_model.summary.roc.collect()  # list of Row(FPR=..., TPR=...)
fpr = [row.FPR for row in roc]
tpr = [row.TPR for row in roc]

Here, your_model could be, for example, a model obtained like this:

from pyspark.ml.classification import LogisticRegression

# df needs a 'features' vector column and a binary 'label' column (the default column names)
log_reg = LogisticRegression()
your_model = log_reg.fit(df)

Now just plot TPR against FPR, for example with matplotlib.
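For intuition about what summary.roc holds, here is a small pure-Python sketch (a hypothetical helper, not Spark's implementation) that computes the same kind of (FPR, TPR) points from (score, label) pairs: sort by score descending, lower the threshold one distinct score at a time, and record the cumulative false and true positive rates. Per the PySpark docs, Spark's DataFrame additionally has (0.0, 0.0) prepended and (1.0, 1.0) appended, so its endpoints may be duplicated relative to this sketch. It runs on a single machine, so it is only meant for checking results on a small sample, not for 3 million scores.

```python
def roc_points(scores_and_labels):
    """Hypothetical helper (not part of Spark): (FPR, TPR) points of an ROC curve.

    scores_and_labels: iterable of (score, label) pairs, label 1.0 = positive.
    Examples with tied scores are consumed together, so each distinct score
    contributes exactly one point.
    """
    pairs = sorted(scores_and_labels, key=lambda p: p[0], reverse=True)
    total_pos = sum(1 for _, label in pairs if label == 1.0)
    total_neg = len(pairs) - total_pos
    if total_pos == 0 or total_neg == 0:
        raise ValueError("need at least one positive and one negative example")

    points = [(0.0, 0.0)]  # threshold above every score: nothing predicted positive
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Predict positive for every example with score >= current threshold
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1.0:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / total_neg, tp / total_pos))
    return points


print(roc_points([(0.9, 1.0), (0.8, 0.0), (0.7, 1.0), (0.1, 0.0)]))
# -> [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
```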

P.S.

Here is a complete example of plotting the ROC curve for a model named your_model. I've also plotted a reference "random guess" line inside the ROC plot.

import matplotlib.pyplot as plt

# Collect the ROC DataFrame once instead of running two separate Spark jobs
roc = your_model.summary.roc.collect()

plt.figure(figsize=(5, 5))
plt.plot([0, 1], [0, 1], 'r--')  # "random guess" reference line
plt.plot([row.FPR for row in roc], [row.TPR for row in roc])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.show()
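As a sanity check on the collected points, a sketch (hypothetical helper, not a Spark API) can integrate them with the trapezoidal rule; the result should be close to the your_model.summary.areaUnderROC value Spark already reports. Zero-width segments, such as a duplicated (1.0, 1.0) endpoint, contribute nothing to the area.

```python
def auc_trapezoid(points):
    """Hypothetical helper: area under a curve given as (x, y) points.

    Points are sorted by (x, y) first; for a monotone ROC curve this puts
    them in curve order, and vertical segments add zero area.
    """
    pts = sorted(points)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0  # area of one trapezoid
    return area


print(auc_trapezoid([(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]))
# -> 0.75
```

With a fitted model, the points would come from something like [(row.FPR, row.TPR) for row in your_model.summary.roc.collect()].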



Answer 2:


For a more general solution that works for models besides Logistic Regression (such as Decision Trees or Random Forests, which, at least in the Spark versions of the time, lack a model summary), you can get the ROC curve using BinaryClassificationMetrics from Spark MLlib.

Note that the PySpark version doesn't implement all of the methods that the Scala version does, so you'll need to call the Scala methods on the wrapped Java object directly (JavaModelWrapper keeps it in _java_model and also offers a .call(name) helper). It also seems that py4j doesn't support converting scala.Tuple2 objects, so they have to be processed manually.

Example:

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        points = []
        # Note this collect could be inefficient for large datasets 
        # considering there may be one probability per datapoint (at most)
        # The Scala version takes a numBins parameter, 
        # but it doesn't seem possible to pass this from Python to Java
        for row in rdd.collect():
            # Results are returned as type scala.Tuple2, 
            # which doesn't appear to have a py4j mapping
            points += [(float(row._1()), float(row._2()))]
        return points

    def get_curve(self, method):
        rdd = getattr(self._java_model, method)().toJavaRDD()
        return self._to_list(rdd)
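Since numBins can't be passed through from Python (see the comment in _to_list), one possible workaround, not from the original answer, is to coarsen the scores before building the metrics. As far as I understand, the curve gets one point per distinct score, so rounding probabilities to d decimals caps it at roughly 10**d + 1 points and shrinks what collect() brings back to the driver. The helper name coarsen_scores is hypothetical; the trade-off is a less detailed curve.

```python
def coarsen_scores(scores_and_labels, decimals=3):
    """Hypothetical helper: round scores so few distinct thresholds remain.

    For probabilities in [0, 1], at most 10**decimals + 1 distinct scores
    survive, which bounds the number of points on the resulting curve.
    Labels are passed through unchanged.
    """
    return [(round(score, decimals), label) for score, label in scores_and_labels]


pairs = [(i / 10000, float(i % 2)) for i in range(10001)]
coarse = coarsen_scores(pairs, decimals=2)
print(len({s for s, _ in pairs}), len({s for s, _ in coarse}))
```

On the Spark side the same idea would be a map over the (score, label) RDD before constructing the metrics, e.g. CurveMetrics(preds.map(lambda p: (round(p[0], 3), p[1]))).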

Usage:

import matplotlib.pyplot as plt

# Create a Pipeline estimator and fit on train DF, predict on test DF
model = estimator.fit(train)
predictions = model.transform(test)

# BinaryClassificationMetrics expects an RDD of (score, label) pairs;
# the score here is the predicted probability of the positive class
preds = predictions.select('label', 'probability').rdd.map(
    lambda row: (float(row['probability'][1]), float(row['label'])))

# Returns a list of (false positive rate, true positive rate) tuples
roc = CurveMetrics(preds).get_curve('roc')

plt.figure()
x_val = [x[0] for x in roc]
y_val = [x[1] for x in roc]
plt.title('ROC curve')
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.plot(x_val, y_val)
plt.show()

BinaryClassificationMetrics in Scala implements several other useful methods as well; the ...ByThreshold ones return (threshold, value) pairs:

metrics = CurveMetrics(preds)
metrics.get_curve('fMeasureByThreshold')
metrics.get_curve('precisionByThreshold')
metrics.get_curve('recallByThreshold')
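To make the meaning of these ...ByThreshold curves concrete, here is a pure-Python sketch (hypothetical helper metrics_by_threshold, not Spark's code) that, for each distinct score t, treats score >= t as a positive prediction and reports precision, recall and F-measure, with F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). Spark returns each metric as its own list of (threshold, value) pairs; this sketch bundles all three per threshold, and its numbers should correspond to Spark's only up to details such as binning.

```python
def metrics_by_threshold(scores_and_labels, beta=1.0):
    """Hypothetical helper: (threshold, precision, recall, F-beta) per distinct score.

    Thresholds are visited from highest to lowest score; at threshold t,
    every example with score >= t counts as predicted positive.
    """
    pairs = sorted(scores_and_labels, key=lambda p: p[0], reverse=True)
    total_pos = sum(1 for _, label in pairs if label == 1.0)
    b2 = beta * beta
    rows = []
    tp = fp = 0
    i = 0
    while i < len(pairs):
        t = pairs[i][0]
        # Consume every example tied at this score before emitting a row
        while i < len(pairs) and pairs[i][0] == t:
            if pairs[i][1] == 1.0:
                tp += 1
            else:
                fp += 1
            i += 1
        precision = tp / (tp + fp)  # tp + fp >= 1 here
        recall = tp / total_pos if total_pos else 0.0
        denom = b2 * precision + recall
        f_measure = (1 + b2) * precision * recall / denom if denom else 0.0
        rows.append((t, precision, recall, f_measure))
    return rows


for row in metrics_by_threshold([(0.9, 1.0), (0.8, 0.0), (0.7, 1.0), (0.1, 0.0)]):
    print(row)
```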


Source: https://stackoverflow.com/questions/52847408/pyspark-extract-roc-curve
