How to cross-validate a RandomForest model?

广开言路 2020-12-09 03:44

I want to evaluate a random forest trained on some data. Is there a utility in Apache Spark for this, or do I have to perform cross-validation manually?

2 answers
  •  情歌与酒
    2020-12-09 04:38

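    To build on zero323's great answer, which uses a random forest classifier, here is a similar example for a random forest regressor. Note that the feature stages have to be part of the estimator passed to CrossValidator (here via a Pipeline), and the label column must be excluded from the feature columns: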

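    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
    import org.apache.spark.ml.regression.RandomForestRegressor // CHANGED
    import org.apache.spark.ml.evaluation.RegressionEvaluator // CHANGED
    import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}
    import org.apache.spark.sql.DataFrame
    
    val numFolds: Int = ??? // number of folds, e.g. 3 or more
    val data: DataFrame = ??? // numeric feature columns plus the "events" label column
    
    // Training (80%) and test data (20%)
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
    
    // Every column except the label is a feature (including "events" would leak the label)
    val featuresCols = data.columns.filter(_ != "events")
    val va = new VectorAssembler()
      .setInputCols(featuresCols)
      .setOutputCol("rawFeatures")
    // Features with at most 5 distinct values are treated as categorical
    val vi = new VectorIndexer()
      .setInputCol("rawFeatures")
      .setOutputCol("features")
      .setMaxCategories(5)
    val regressor = new RandomForestRegressor()
      .setLabelCol("events")
      .setFeaturesCol("features")
    
    // Chain the feature stages and the model so CrossValidator refits all of them on each fold
    val pipeline = new Pipeline().setStages(Array(va, vi, regressor))
    
    val metric = "rmse"
    val evaluator = new RegressionEvaluator()
      .setLabelCol("events")
      .setPredictionCol("prediction")
      //     "rmse" (default): root mean squared error
      //     "mse": mean squared error
      //     "r2": R2 metric
      //     "mae": mean absolute error
      .setMetricName(metric)
    
    // An empty grid cross-validates only the default parameters; add e.g.
    // .addGrid(regressor.numTrees, Array(20, 50)) to also tune hyperparameters
    val paramGrid = new ParamGridBuilder().build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(numFolds)
    
    val model = cv.fit(train) // CrossValidatorModel
    // Cross-validated metric averaged over the folds, one value per param map
    println(model.avgMetrics.mkString(", "))
    
    // Metric on the held-out test set, using the best model found
    val predictions = model.transform(test)
    predictions.show()
    val rmse = evaluator.evaluate(predictions)
    println(rmse)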
    

    Evaluator metric source: https://spark.apache.org/docs/latest/api/scala/#org.apache.spark.ml.evaluation.RegressionEvaluator
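    For intuition about what CrossValidator automates (and what doing it manually would involve), here is a plain-Scala sketch of k-fold cross-validation without Spark. It is an illustration only, not Spark's implementation: the fold assignment is a deterministic round-robin (Spark assigns rows to folds randomly), and the "model" is a hypothetical baseline that predicts the mean training label. For each fold it fits on the other k-1 folds, computes RMSE on the held-out fold, and averages the k scores, which corresponds to the per-param-map average that CrossValidatorModel.avgMetrics reports.

```scala
// Hypothetical illustration of k-fold cross-validation; not Spark's CrossValidator code.
object KFoldSketch {

  /** Root mean squared error between labels and predictions. */
  def rmse(labels: Seq[Double], preds: Seq[Double]): Double = {
    require(labels.nonEmpty && labels.size == preds.size)
    val mse = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum / labels.size
    math.sqrt(mse)
  }

  /** Splits rows into k (train, validation) pairs; row i is validated in fold i % k. */
  def folds[A](rows: IndexedSeq[A], k: Int): Seq[(IndexedSeq[A], IndexedSeq[A])] = {
    require(k >= 2 && k <= rows.size)
    (0 until k).map { f =>
      val (valid, train) = rows.zipWithIndex.partition { case (_, i) => i % k == f }
      (train.map(_._1), valid.map(_._1))
    }
  }

  /** Average validation RMSE of a baseline that predicts the training-fold mean label. */
  def crossValidateMeanModel(labels: IndexedSeq[Double], k: Int): Double = {
    val scores = folds(labels, k).map { case (train, valid) =>
      val mean = train.sum / train.size    // "fit": mean of the training labels
      rmse(valid, valid.map(_ => mean))    // "evaluate" on the held-out fold
    }
    scores.sum / scores.size               // average over folds
  }

  def main(args: Array[String]): Unit = {
    val labels = IndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    println(crossValidateMeanModel(labels, 3))
  }
}
```

    Running main prints roughly 1.914, the average of the three per-fold RMSE values. Replacing the fit and evaluate steps inside crossValidateMeanModel with a real model reuses the same fold loop, which is essentially the bookkeeping CrossValidator does for you.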
