How to extract best parameters from a CrossValidatorModel

轮回少年 2020-12-13 02:31

I want to find the ParamGridBuilder parameters that produce the best model in a CrossValidator in Spark 1.4.x.

In the Pipeline example in the Spark documentation,

11 answers
  •  借酒劲吻你
    2020-12-13 03:16

    This Stack Overflow thread more or less answers the question.

    In a nutshell, you need to cast each object to the concrete class it is expected to be.

    For a CrossValidatorModel, this is what I did:

    import org.apache.spark.ml.tuning.CrossValidatorModel
    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.ml.regression.RandomForestRegressionModel
    
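    // Load the saved CV model from S3 (ML persistence is not available in Spark 1.4.x)
    val inputModelPath = "s3://path/to/my/random-forest-regression-cv"
    val reloadedCvModel = CrossValidatorModel.load(inputModelPath)
    
    // bestModel is typed as Model[_], so cast it to PipelineModel first,
    // then cast the fitted stage at index 1 to its concrete model class
    val bestRfModel = reloadedCvModel.bestModel
        .asInstanceOf[PipelineModel]
        .stages(1)
        .asInstanceOf[RandomForestRegressionModel]
    
    // extractParamMap() returns the stage's params with their default and explicitly set values
    println(bestRfModel.extractParamMap())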
    

    In this example my pipeline has two stages (a VectorIndexer and a RandomForestRegressor), so the fitted random forest model is at stage index 1. Adjust the index to match the position of the model in your own pipeline.
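    Two variations on the approach above, sketched as small hypothetical helpers (`BestModelHelpers`, `firstOfType` and `bestCandidate` are illustrative names, not Spark API). `firstOfType` picks a pipeline stage by its class instead of a hard-coded index, so it keeps working if stages are added or reordered. `bestCandidate` recovers the winning grid entry directly by pairing each `ParamMap` from `getEstimatorParamMaps` with the matching entry of `avgMetrics` (the average cross-validation metric per param map, in the same order), assuming a Spark version whose `CrossValidatorModel` exposes `avgMetrics`. The helpers are kept generic so the sketch compiles without Spark; the comments show how they might be called on a real model.

```scala
import scala.reflect.ClassTag

// Hypothetical helpers: the names are illustrative, not part of Spark's API.
object BestModelHelpers {

  // Return the first element whose runtime class is T, instead of relying on a
  // hard-coded index. With Spark this might be called as:
  //   firstOfType[RandomForestRegressionModel](pipelineModel.stages.toSeq)
  def firstOfType[T: ClassTag](xs: Seq[Any]): Option[T] =
    xs.collectFirst { case t: T => t }

  // Pair each candidate with its metric (same order) and return the candidate
  // with the best metric. With Spark this might be called as:
  //   bestCandidate(cvModel.getEstimatorParamMaps.toSeq,
  //                 cvModel.avgMetrics.toSeq,
  //                 cvModel.getEvaluator.isLargerBetter)
  def bestCandidate[P](candidates: Seq[P], metrics: Seq[Double], largerIsBetter: Boolean): P = {
    require(candidates.nonEmpty, "no candidates")
    require(candidates.length == metrics.length, "candidates and metrics must align")
    val pairs = candidates.zip(metrics)
    // Pick max or min depending on whether the metric is "larger is better"
    val best = if (largerIsBetter) pairs.maxBy(_._2) else pairs.minBy(_._2)
    best._1
  }
}
```

    Passing `isLargerBetter` matters: for a metric such as RMSE, where lower is better, taking the maximum would return the worst grid entry rather than the best one.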
