How to cross validate RandomForest model?

后端 未结 2 584
广开言路
广开言路 2020-12-09 03:44

I want to evaluate a random forest trained on some data. Is there a utility in Apache Spark for this, or do I have to perform cross-validation manually?

相关标签:
2条回答
  • 2020-12-09 04:38

    Building on zero323's answer for RandomForestClassifier, here is a similar example for RandomForestRegressor:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
    import org.apache.spark.ml.regression.RandomForestRegressor // CHANGED
    import org.apache.spark.ml.evaluation.RegressionEvaluator // CHANGED
    import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}
    
    // Placeholders: supply your own fold count and input data.
    // The explicit types matter: a bare `???` is typed `Nothing`, which has no
    // `randomSplit`/`columns` members, so the snippet would not compile.
    val numFolds: Int = ???
    // NOTE: assumes every non-label column is numeric (VectorAssembler requirement).
    val data: org.apache.spark.sql.DataFrame = ???

    // Name of the target column; it must NOT be part of the feature vector.
    val labelCol = "events"

    // Training (80%) and test data (20%)
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))

    // Every column except the label becomes a feature. Assembling the label
    // as well would leak the target into the model.
    val featuresCols = data.columns.filterNot(_ == labelCol)
    val va = new VectorAssembler()
      .setInputCols(featuresCols)
      .setOutputCol("rawFeatures")
    // Features with at most 5 distinct values are treated as categorical.
    val vi = new VectorIndexer()
      .setInputCol("rawFeatures")
      .setOutputCol("features")
      .setMaxCategories(5)
    val regressor = new RandomForestRegressor()
      .setLabelCol(labelCol)
      .setFeaturesCol("features")

    // Chain feature preparation and the model so the assembler/indexer are
    // actually applied, and are refit on each fold's training split.
    val pipeline = new Pipeline().setStages(Array(va, vi, regressor))

    // Supported metrics: "rmse" (default), "mse", "r2", "mae".
    val metric = "rmse"
    val evaluator = new RegressionEvaluator()
      .setLabelCol(labelCol)
      .setPredictionCol("prediction")
      .setMetricName(metric)

    // Empty grid: no hyper-parameter search, plain k-fold cross-validation.
    val paramGrid = new ParamGridBuilder().build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(numFolds)

    val model = cv.fit(train) // train: DataFrame
    // Cross-validated metric averaged over the folds, one entry per param map.
    println(model.avgMetrics.mkString(", "))

    val predictions = model.transform(test)
    predictions.show()
    // Held-out score on the 20% test split, using the metric chosen above.
    val testMetric = evaluator.evaluate(predictions)
    println(testMetric)
    

    Source for the evaluator metrics: https://spark.apache.org/docs/latest/api/scala/#org.apache.spark.ml.evaluation.RegressionEvaluator

    0 讨论(0)
  • 2020-12-09 04:43

    Spark ML provides the CrossValidator class, which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed, you can add cross-validation as follows:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    
    // Placeholders: supply your own values.
    // Expected schema: [label: double, features: vector]
    val trainingData: org.apache.spark.sql.DataFrame = ???
    val nFolds: Int = ???
    val numTrees: Int = ???
    // One of "f1" (default), "weightedPrecision", "weightedRecall", "accuracy".
    val metric: String = ???

    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(numTrees)

    val pipeline = new Pipeline().setStages(Array(rf))

    val paramGrid = new ParamGridBuilder().build() // No parameter search

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName(metric)

    val cv = new CrossValidator()
      // ml.Pipeline with ml.classification.RandomForestClassifier
      .setEstimator(pipeline)
      // ml.evaluation.MulticlassClassificationEvaluator
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(nFolds)

    val model = cv.fit(trainingData) // trainingData: DataFrame
    // Cross-validated metric averaged over the folds, one entry per param map.
    println(model.avgMetrics.mkString(", "))
    

    Using PySpark:

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import RandomForestClassifier
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    
    # Placeholders: supply your own values.
    trainingData = ...  # DataFrame[label: double, features: vector]
    numFolds = ...  # Integer

    rf = RandomForestClassifier(labelCol="label", featuresCol="features")
    evaluator = MulticlassClassificationEvaluator()  # + other params as in Scala

    pipeline = Pipeline(stages=[rf])
    # ParamGridBuilder must be instantiated before chaining addGrid calls.
    paramGrid = (ParamGridBuilder()
        .addGrid(rf.numTrees, [3, 10])
        # .addGrid(rf.maxDepth, [5, 10])  # add other parameters as needed
        .build())

    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=numFolds)

    model = crossval.fit(trainingData)
    # Cross-validated metric averaged over the folds, one entry per param map.
    print(model.avgMetrics)
    
    0 讨论(0)
提交回复
热议问题