How to print best model params in pyspark pipeline

Submitted by 允我心安 on 2019-12-12 19:08:10

Question


This question is similar to this one. I would like to print the best model params after doing a TrainValidationSplit in pyspark. I cannot find the piece of output the other user relies on in their answer because I'm working in Jupyter and the log disappears from the terminal...

Part of the code is:

from pyspark.ml import Pipeline
from pyspark.ml.feature import PCA
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

pca = PCA(inputCol='features')
dt = DecisionTreeRegressor(featuresCol=pca.getOutputCol(), labelCol="energy")
pipe = Pipeline(stages=[pca, dt])

paramgrid = (ParamGridBuilder()
             .addGrid(pca.k, range(1, 50, 2))
             .addGrid(dt.maxDepth, range(1, 10, 1))
             .build())

tvs = TrainValidationSplit(
    estimator=pipe,
    evaluator=RegressionEvaluator(labelCol="energy", predictionCol="prediction",
                                  metricName="mae"),
    estimatorParamMaps=paramgrid,
    trainRatio=0.66)

model = tvs.fit(wind_tr_va)

Thanks in advance.


Answer 1:


This indeed follows the same reasoning as the answer to How to get the maxDepth from a Spark RandomForestRegressionModel given by @user6910411.

You'll need to patch PCAModel and DecisionTreeRegressionModel as follows (TrainValidationSplitModel already exposes the fitted bestModel as an instance attribute, so it needs no patch of its own):

from pyspark.ml.feature import PCAModel
from pyspark.ml.regression import DecisionTreeRegressionModel

# Delegate to the underlying JVM objects, where these getters exist:
PCAModel.getK = lambda self: self._java_obj.getK()
DecisionTreeRegressionModel.getMaxDepth = lambda self: self._java_obj.getMaxDepth()

Now you can use this to get the best model and extract k and maxDepth:

bestModel = model.bestModel

bestModelK = bestModel.stages[0].getK()
bestModelMaxDepth = bestModel.stages[1].getMaxDepth()

PS: You can patch models to extract other specific parameters in the same way as described above.
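The patching pattern can be illustrated without a Spark installation at all. Below is a minimal, PySpark-free sketch; the class names (JavaPCAModel, PCAModelWrapper) are made-up stand-ins for the JVM-side model and its thin Python wrapper, not real PySpark classes:

```python
# Stand-in for the JVM-side model, which carries the real getter
# (hypothetical class for illustration only).
class JavaPCAModel:
    def getK(self):
        return 7

# Stand-in for the thin Python wrapper, which keeps its JVM twin
# in `_java_obj` the way PySpark wrappers do (hypothetical class).
class PCAModelWrapper:
    def __init__(self):
        self._java_obj = JavaPCAModel()

# The patch from the answer: attach a getter to the wrapper class
# that delegates to the underlying JVM object.
PCAModelWrapper.getK = lambda self: self._java_obj.getK()

print(PCAModelWrapper().getK())  # prints 7
```

The same one-line lambda works for any parameter whose getter exists only on the JVM side.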




Answer 2:


Even simpler (a one-liner): just refer to the JVM object of your model:

    cvModel.bestModel.stages[-1]._java_obj.getMaxDepth()

Here you take your bestModel after cross-validation, access the JVM object backing this model, and extract the maxDepth parameter using the JVM object's getMaxDepth() method.

The list of all the original JVM getters can be found here: https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html

Likewise, you can browse the getters of any other model and call them through that model's original JVM object:

    <yourModel>.stages[<yourModelStage>]._java_obj.<getParameter>()
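To see the shape of that generic pattern without Spark, here is a sketch with made-up stand-in classes (JavaDecisionTree, DecisionTreeWrapper, and PipelineModelStub are all hypothetical, standing in for the JVM model, its Python wrapper, and the best pipeline model):

```python
class JavaDecisionTree:
    """Stand-in for the JVM-side model exposing the original getters."""
    def getMaxDepth(self):
        return 5

class DecisionTreeWrapper:
    """Stand-in for the Python wrapper holding its JVM twin in `_java_obj`."""
    def __init__(self):
        self._java_obj = JavaDecisionTree()

class PipelineModelStub:
    """Stand-in for bestModel: a pipeline whose last stage is the tree."""
    def __init__(self):
        self.stages = [DecisionTreeWrapper()]

best_model = PipelineModelStub()
# <yourModel>.stages[<yourModelStage>]._java_obj.<getParameter>()
print(best_model.stages[-1]._java_obj.getMaxDepth())  # prints 5
```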

Hope it helps.



Source: https://stackoverflow.com/questions/41781529/how-to-print-best-model-params-in-pyspark-pipeline
