RandomForestClassifier was given input with invalid label column error in Apache Spark

Answered by zero323:

RandomForestClassifier, like many other ML algorithms, requires specific metadata to be set on the label column, and the label values have to be integral values from [0, 1, 2, ..., #classes), represented as doubles. Typically this is handled by an upstream Transformer such as StringIndexer. Because you converted the labels manually, the metadata fields are not set, so the classifier cannot confirm that these requirements are satisfied.

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

val df = Seq(
  (0.0, Vectors.dense(1, 0, 0, 0)),
  (1.0, Vectors.dense(0, 1, 0, 0)),
  (2.0, Vectors.dense(0, 0, 1, 0)),
  (2.0, Vectors.dense(0, 0, 0, 1))
).toDF("label", "features")

val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setNumTrees(5)

rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...

You can re-encode the label column using StringIndexer:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .fit(df)

rf.setLabelCol("label_idx").fit(indexer.transform(df))
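Because the model is trained on the re-encoded column, its prediction column also contains indices. A minimal sketch of mapping predictions back to the original labels with IndexToString (model, indexed, and predicted_label are illustrative names; prediction is the classifier's default output column):

import org.apache.spark.ml.feature.IndexToString

val indexed = indexer.transform(df)
val model = rf.setLabelCol("label_idx").fit(indexed)

// Reuse the label mapping learned by the fitted StringIndexer
// to decode predicted indices back to the original labels.
val converter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predicted_label")
  .setLabels(indexer.labels)

converter.transform(model.transform(indexed))
  .select("label", "predicted_label")
  .show()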

Alternatively, you can set the required metadata manually:

import org.apache.spark.ml.attribute.NominalAttribute

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0", "2.0")
  .toMetadata

rf.setLabelCol("label_meta").fit(
  df.withColumn("label_meta", $"label".as("", meta))
)
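To confirm the metadata was actually attached, you can inspect the resulting StructField; a quick sketch using the column name from the snippet above:

val withMeta = df.withColumn("label_meta", $"label".as("", meta))

// The nominal attribute is now stored on the column's metadata.
withMeta.schema("label_meta").metadata
// e.g. {"ml_attr":{"vals":["0.0","1.0","2.0"],"type":"nominal","name":"label"}}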

Note:

Labels created using StringIndexer are ordered by frequency, not by value:

indexer.labels
// Array[String] = Array(2.0, 0.0, 1.0)
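If you need indices assigned by label value instead, StringIndexer supports a stringOrderType parameter (available since Spark 2.3); a sketch:

// Assumes Spark 2.3+: assign indices by (string) value rather than frequency.
val indexerByValue = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .setStringOrderType("alphabetAsc")
  .fit(df)

indexerByValue.labels
// Array[String] = Array(0.0, 1.0, 2.0)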

PySpark:

In Python, metadata fields can be set directly on the schema:

from pyspark.sql.types import StructField, DoubleType

# Nominal label field carrying the "ml_attr" metadata Spark ML expects.
StructField(
    "label", DoubleType(), False,
    {"ml_attr": {
        "name": "label",
        "type": "nominal",
        "vals": ["0.0", "1.0", "2.0"]
    }}
)