Spark MultilayerPerceptronClassifier Class Probabilities

Posted by 放荡痞女 on 2019-12-24 20:54:55

Question


I am an experienced Python programmer trying to transition some Python code to Spark for a classification task. This is my first time working in Spark/Scala.

In Python, both Keras/TensorFlow and scikit-learn neural networks do a great job on multi-class classification, and I can easily return the top 3 most probable classes along with their probabilities, which are key to this project.

I have been generally successful in moving the code to Spark (Scala), and I can generate the correct predictions, but I have not found a way to return probabilities for the top predicted classes from the MultilayerPerceptronClassifier in MLlib.

The closest solution I found was in this post: "How to get classification probabilities from MultilayerPerceptronClassifier?" However, I can't get the solution in that post to work, either because it is missing a key piece of code or because I'm too new to Scala (probably the latter) to make the needed adjustments.

Has anyone solved this problem?

These are the current versions in my environment: Spark 2.1.1, Scala 2.11.8.

Thanks for your help,

RKB


Answer 1:


If you look carefully at the results of MultilayerPerceptronClassificationModel.transform (with model and test defined as in the example pipeline in the official documentation)

val result = model.transform(test)

result.printSchema
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

you'll see that they contain a probability column.

It is stored as an org.apache.spark.ml.linalg.Vector column:

result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability                                        |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows

and can be accessed using the standard Vector methods, such as toArray or apply(i).
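Since the question asks for the top 3 classes with their probabilities, here is a minimal sketch of that step in plain Scala. The helper name topK is hypothetical (it is not part of Spark), and it assumes the probabilities have already been pulled out of the vector as an Array[Double], for example with toArray. The sample values are the first row of the output shown above.

```scala
// Hypothetical helper (not part of Spark): returns the k most probable
// classes as (classIndex, probability) pairs, highest probability first.
def topK(probs: Array[Double], k: Int): Seq[(Int, Double)] =
  probs.zipWithIndex                    // pair each probability with its class index
    .sortBy { case (p, _) => -p }       // sort by descending probability
    .take(k)                            // keep at most k entries
    .map { case (p, idx) => (idx, p) }  // reorder to (classIndex, probability)
    .toList

// Probabilities copied from the first row of the output shown above.
val row = Array(2.630203838780848E-29, 1.7323171642231641E-19, 1.0)

// Class 2 comes first (probability 1.0), followed by class 1.
val top2 = topK(row, 2)
println(top2)
```

In spark-shell, one way to apply this to the whole DataFrame would be to wrap it in a UDF, along the lines of udf((v: org.apache.spark.ml.linalg.Vector) => topK(v.toArray, 3)), and add the result with withColumn. Treat that wiring as an untested sketch. It also assumes the class index equals the label value, which holds when labels are encoded as 0 to numClasses - 1.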

This feature has been available since Spark 2.3 (SPARK-12664: Expose probability, rawPrediction in MultilayerPerceptronClassificationModel), so it is not present in Spark 2.1.1, the version mentioned in the question.



Source: https://stackoverflow.com/questions/54545639/spark-multilayerperceptronclassifier-class-probabilities
