SPARK, ML, Tuning, CrossValidator: access the metrics

Submitted by 早过忘川 on 2019-12-05 20:17:31

Question


In order to build a NaiveBayes multiclass classifier, I am using a CrossValidator to select the best parameters in my pipeline:

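import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
        .setEstimator(pipeline)
        .setEstimatorParamMaps(paramGrid)
        .setEvaluator(new MulticlassClassificationEvaluator)
        .setNumFolds(10)

val cvModel = cv.fit(trainingSet)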

The pipeline contains the usual transformers and estimators in the following order: Tokenizer, StopWordsRemover, HashingTF, IDF and finally NaiveBayes.

Is it possible to access the metrics calculated for the best model?

Ideally, I would like to access the metrics of all models to see how changing the parameters affects the quality of the classification. But for the moment, the best model is good enough.

FYI, I am using Spark 1.6.0


Answer 1:


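Here's how I do it (my pipeline ends in a decision tree classifier rather than NaiveBayes, but the approach is the same):

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.feature.{HashingTF, IDFModel, Word2VecModel}
import org.apache.spark.ml.tuning.ParamGridBuilder

val pipeline = new Pipeline()
  .setStages(Array(tokenizer, stopWordsFilter, tf, idf, word2Vec, featureVectorAssembler, categoryIndexerModel, classifier, categoryReverseIndexer))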

...

val paramGrid = new ParamGridBuilder()
  .addGrid(tf.numFeatures, Array(10, 100))
  .addGrid(idf.minDocFreq, Array(1, 10))
  .addGrid(word2Vec.vectorSize, Array(200, 300))
  .addGrid(classifier.maxDepth, Array(3, 5))
  .build()

paramGrid.size // 16 entries (2 x 2 x 2 x 2)

...

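// Average metric (over all folds) for each ParamGrid entry, in the same order as paramGrid
val avgMetricsParamGrid = crossValidatorModel.avgMetrics

// Combine with paramGrid to see how each parameter combination affects the metric
val combined = paramGrid.zip(avgMetricsParamGrid)
combined.foreach { case (params, metric) => println(s"$metric\n$params") }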

...

val bestModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel]

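// Explain params for each stage
// (explainParams returns a String describing every param of that stage and its value,
// not only the one named in the variable)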
val bestHashingTFNumFeatures = bestModel.stages(2).asInstanceOf[HashingTF].explainParams
val bestIDFMinDocFrequency = bestModel.stages(3).asInstanceOf[IDFModel].explainParams
val bestWord2VecVectorSize = bestModel.stages(4).asInstanceOf[Word2VecModel].explainParams
val bestDecisionTreeDepth = bestModel.stages(7).asInstanceOf[DecisionTreeClassificationModel].explainParams
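The combined pairs are easier to read when sorted by score. The plain-Scala sketch below uses a hypothetical helper, rankByMetric (not part of Spark), and plain strings in place of ParamMap objects so it runs without Spark. It takes a largerIsBetter flag because CrossValidator keeps the entry with the highest average when the evaluator's isLargerBetter is true, and the lowest otherwise (as with RMSE for regression). The scores in main are made up for illustration.

```scala
object MetricRanking {
  /** Hypothetical helper: pairs each candidate with its averaged metric and
    * sorts best first. `grid` and `metrics` must be in the same order, as
    * paramGrid and avgMetrics are for a CrossValidatorModel. */
  def rankByMetric[P](grid: Seq[P], metrics: Seq[Double], largerIsBetter: Boolean): Seq[(P, Double)] = {
    require(grid.size == metrics.size, "one metric per candidate is expected")
    val paired = grid.zip(metrics)
    // sortBy is stable, so ties keep their original grid order
    if (largerIsBetter) paired.sortBy(-_._2) else paired.sortBy(_._2)
  }

  def main(args: Array[String]): Unit = {
    // Made-up param descriptions and scores standing in for paramGrid / avgMetrics
    val grid = Seq("numFeatures=10", "numFeatures=100", "numFeatures=1000")
    val avgMetrics = Seq(0.61, 0.78, 0.74)
    rankByMetric(grid, avgMetrics, largerIsBetter = true).foreach {
      case (params, metric) => println(f"$metric%.2f  $params")
    }
  }
}
```

Against a fitted model, a call such as rankByMetric(paramGrid.toSeq, crossValidatorModel.avgMetrics.toSeq, evaluator.isLargerBetter) should list the best combination first, where evaluator stands for the evaluator that was passed to the CrossValidator.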



Answer 2:


cvModel.avgMetrics

works in PySpark 2.2.0 as well.
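Whichever API reads it, avgMetrics holds one number per param-grid entry: that entry's evaluation metric averaged over the k folds (with a default MulticlassClassificationEvaluator, the metric is f1). The sketch below illustrates that averaging in plain Scala. It is not Spark's source code; the foldScores layout (one row per fold, one column per param map) and the scores themselves are assumptions made for illustration.

```scala
object FoldAveraging {
  /** foldScores(f)(i) = metric of param map i evaluated on fold f (hypothetical layout).
    * Returns one average per param map, in grid order, like avgMetrics. */
  def averageOverFolds(foldScores: Seq[Seq[Double]]): Seq[Double] = {
    require(foldScores.nonEmpty, "at least one fold is needed")
    val numParams = foldScores.head.size
    require(foldScores.forall(_.size == numParams), "every fold must score every param map")
    (0 until numParams).map(i => foldScores.map(_(i)).sum / foldScores.size)
  }

  def main(args: Array[String]): Unit = {
    // Made-up scores: 3 folds x 2 param maps
    val foldScores = Seq(
      Seq(0.70, 0.80),
      Seq(0.60, 0.90),
      Seq(0.80, 0.70)
    )
    val avg = averageOverFolds(foldScores)
    // For a larger-is-better metric such as f1, the best entry is the maximum average
    val bestIndex = avg.indexOf(avg.max)
    println(s"averages = $avg, best param map index = $bestIndex")
  }
}
```

The index picked this way corresponds to the param map that CrossValidator refits on the full training set to produce bestModel, assuming a metric where larger is better.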



Source: https://stackoverflow.com/questions/34678818/spark-ml-tuning-crossvalidator-access-the-metrics
