How to extract best parameters from a CrossValidatorModel

轮回少年 2020-12-13 02:31

I want to find the ParamGridBuilder parameters that produce the best model in CrossValidator in Spark 1.4.x.

In Pipeline Example in Spark documentation,

11 answers
  •  伪装坚强ぢ
    2020-12-13 02:53

    One way to get the winning ParamMap object is to pair each entry of getEstimatorParamMaps with the matching entry of CrossValidatorModel.avgMetrics (an Array[Double]) and take the argmax:

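    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.tuning.CrossValidatorModel

    // Adds bestEstimatorParamMap to any CrossValidatorModel in scope.
    implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
      def bestEstimatorParamMap: ParamMap = {
        cvModel.getEstimatorParamMaps
               .zip(cvModel.avgMetrics) // pair each ParamMap with its average metric
               .maxBy(_._2)             // assumes a larger metric is better
               ._1
      }
    }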
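
One caveat with the argmax approach: it only picks the right entry when a larger metric is better. For a metric that should be minimized (RMSE, for example) the argmin is needed instead. Below is a minimal plain-Scala sketch of the same zip-and-select step with that choice made explicit. It has no Spark dependency; `BestParamsSketch`, `bestCandidate`, and the `largerIsBetter` flag are hypothetical names for illustration, and the candidates stand in for ParamMap objects. Newer Spark releases have an `isLargerBetter` method on Evaluator that could supply this flag.

```scala
// Plain-Scala sketch: any values can stand in for ParamMap objects.
object BestParamsSketch {
  // Hypothetical helper: pairs each candidate with its average metric
  // and returns the candidate with the best score.
  def bestCandidate[A](candidates: Seq[A],
                       avgMetrics: Seq[Double],
                       largerIsBetter: Boolean = true): A = {
    require(candidates.nonEmpty, "need at least one candidate")
    require(candidates.length == avgMetrics.length,
      "candidates and avgMetrics must have the same length")
    val scored = candidates.zip(avgMetrics)
    // Pick argmax for metrics to maximize, argmin for metrics to minimize.
    val (best, _) =
      if (largerIsBetter) scored.maxBy(_._2) else scored.minBy(_._2)
    best
  }
}
```

With a real model, the two sequences would be `cvModel.getEstimatorParamMaps` and `cvModel.avgMetrics`, which are in the same order.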
    

    Run on the CrossValidatorModel trained in the Pipeline Example you cited, this gives:

    scala> println(cvModel.bestEstimatorParamMap)
    {
       hashingTF_2b0b8ccaeeec-numFeatures: 100,
       logreg_950a13184247-regParam: 0.1
    }
    
