Spark Multiclass Classification Example

渐次进展 2020-12-25 09:35

Do you guys know where I can find examples of multiclass classification in Spark? I spent a lot of time searching in books and on the web, and so far I just know that it is

2 Answers
  • 2020-12-25 10:00

    Are you using Spark 1.6 rather than Spark 2.1? I think the problem is that in Spark 2.1 the transform method returns a Dataset, which can be implicitly converted to a typed RDD, whereas prior to that it returns a DataFrame of Rows.

    As a diagnostic, try specifying the return type of the transform function as RDD[LabeledPoint] and see whether you get the same error.
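
    A hypothetical sketch of what that diagnostic looks like (the function name and column names are placeholders, not taken from the original question); with the explicit annotation the compiler reports a type mismatch at the definition rather than further downstream:

    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.regression.LabeledPoint
    
    // Explicit return type: the compiler now checks that this really
    // produces an RDD[LabeledPoint].
    def toLabeledPoints(df: DataFrame): RDD[LabeledPoint] =
      df.select("label", "features").rdd.map {
        case Row(label: Double, features: Vector) => LabeledPoint(label, features)
      }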

  • 2020-12-25 10:09

    ML

    (Recommended in Spark 2.0+)

    We'll use the same data as in the MLlib section below. There are two basic options. If the Estimator supports multiclass classification out of the box (for example, random forest), you can use it directly:

    // trainRaw is the RDD[LabeledRecord] defined in the MLlib section below
    val trainRawDf = trainRaw.toDF
    
    import org.apache.spark.ml.feature.{Tokenizer, CountVectorizer, StringIndexer}
    import org.apache.spark.ml.Pipeline
    
    import org.apache.spark.ml.classification.RandomForestClassifier
    
    // Index string labels, tokenize the text and build term-frequency features
    val transformers = Array(
      new StringIndexer().setInputCol("group").setOutputCol("label"),
      new Tokenizer().setInputCol("text").setOutputCol("tokens"),
      new CountVectorizer().setInputCol("tokens").setOutputCol("features")
    )
    
    
    // Random forest handles multiple classes out of the box
    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
    
    // Chain the feature transformers and the classifier into a single Pipeline
    val model = new Pipeline().setStages(transformers :+ rf).fit(trainRawDf)
    
    model.transform(trainRawDf)
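
    A minimal sketch of how you could evaluate the fitted pipeline with MulticlassClassificationEvaluator ("f1" is the default metric; "accuracy" is also supported in 2.0+):

    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    
    val predictions = model.transform(trainRawDf)
    
    // Compare the "label" column with the "prediction" column produced by the model
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("f1")
    
    val f1 = evaluator.evaluate(predictions)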
    

    If the model supports only binary classification (e.g. logistic regression) and extends o.a.s.ml.classification.Classifier, you can use the one-vs-rest strategy:

    import org.apache.spark.ml.classification.OneVsRest
    import org.apache.spark.ml.classification.LogisticRegression
    
    val lr = new LogisticRegression() 
      .setLabelCol("label")
      .setFeaturesCol("features")
    
    val ovr = new OneVsRest().setClassifier(lr)
    
    val ovrModel = new Pipeline().setStages(transformers :+ ovr).fit(trainRawDf)
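
    The fitted one-vs-rest pipeline is applied the same way as before; a short usage sketch:

    // Show the original group, the indexed label and the predicted class
    ovrModel.transform(trainRawDf)
      .select("group", "label", "prediction")
      .show()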
    

    MLlib

    According to the official documentation, at this moment (MLlib 1.6.0) the following methods support multiclass classification:

    • logistic regression,
    • decision trees,
    • random forests,
    • naive Bayes

    At least some of the examples use multiclass classification:

    • Naive Bayes example - 3 classes
    • Logistic regression - 10 classes for the classifier, although only 2 in the example data

    The general framework, ignoring method-specific arguments, is pretty much the same as for all the other methods in MLlib. You have to pre-process your input to create either a DataFrame with columns representing the label and features:

    root
     |-- label: double (nullable = true)
     |-- features: vector (nullable = true)
    

    or an RDD[LabeledPoint].
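
    For example, a hand-built RDD[LabeledPoint] with made-up feature values (just a sketch of the shape the classifier expects) could look like this:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    
    // Labels are doubles (0.0, 1.0, 2.0, ...); features are mllib vectors
    val tiny = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 3.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)),
      LabeledPoint(2.0, Vectors.sparse(3, Array(0, 2), Array(1.0, 4.0)))
    ))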

    Spark provides a broad range of useful tools designed to facilitate this process, including Feature Extractors, Feature Transformers, and Pipelines.

    You'll find a rather naive example of using Random Forest below.

    First, let's import the required packages and create some dummy data:

    import sqlContext.implicits._
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer} 
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.linalg.{Vectors, Vector}
    import org.apache.spark.mllib.evaluation.MulticlassMetrics
    import org.apache.spark.sql.Row
    import org.apache.spark.rdd.RDD
    
    case class LabeledRecord(group: String, text: String)
    
    val trainRaw = sc.parallelize(
        LabeledRecord("foo", "foo v a y b  foo") ::
        LabeledRecord("bar", "x bar y bar v") ::
        LabeledRecord("bar", "x a y bar z") ::
        LabeledRecord("foobar", "foo v b bar z") ::
        LabeledRecord("foo", "foo x") ::
        LabeledRecord("foobar", "z y x foo a b bar v") ::
        Nil
    )
    

    Now let's define the required transformers and process the training dataset:

    // Tokenizer to process text fields
    val tokenizer = new Tokenizer()
        .setInputCol("text")
        .setOutputCol("words")
    
    // HashingTF to convert tokens to the feature vector
    val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("features")
        .setNumFeatures(10)
    
    // Indexer to convert String labels to Double
    val indexer = new StringIndexer()
        .setInputCol("group")
        .setOutputCol("label")
        .fit(trainRaw.toDF)
    
    
    def transform(rdd: RDD[LabeledRecord]): RDD[LabeledPoint] = {
        val tokenized = tokenizer.transform(rdd.toDF)
        val hashed = hashingTF.transform(tokenized)
        val indexed = indexer.transform(hashed)
        indexed
            .select($"label", $"features")
            .map{case Row(label: Double, features: Vector) =>
                LabeledPoint(label, features)}
    }
    
    val train: RDD[LabeledPoint] = transform(trainRaw)
    

    Please note that the indexer is "fitted" on the training data. This simply means that the categorical values used as labels are converted to doubles. To use the classifier on new data, you have to transform that data first using the same indexer.
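
    If you need to map numeric predictions back to the original group names, the fitted indexer exposes its label array; a small sketch:

    // Index i in this array corresponds to label i assigned by the StringIndexer
    val labelNames: Array[String] = indexer.labels
    labelNames.zipWithIndex.foreach { case (name, i) => println(s"$i -> $name") }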

    Next we can train the RF model:

    val numClasses = 3  // three distinct groups: foo, bar, foobar
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 10
    val featureSubsetStrategy = "auto"
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 16
    
    val model = RandomForest.trainClassifier(
        train, numClasses, categoricalFeaturesInfo, 
        numTrees, featureSubsetStrategy, impurity,
        maxDepth, maxBins
    )
    

    and finally test it:

    val testRaw = sc.parallelize(
        LabeledRecord("foo", "foo  foo z z z") ::
        LabeledRecord("bar", "z bar y y v") ::
        LabeledRecord("bar", "a a  bar a z") ::
        LabeledRecord("foobar", "foo v b bar z") ::
        LabeledRecord("foobar", "a foo a bar") ::
        Nil
    )
    
    val test: RDD[LabeledPoint] = transform(testRaw)
    
    val predsAndLabs = test.map(lp => (model.predict(lp.features), lp.label))
    val metrics = new MulticlassMetrics(predsAndLabs)
    
    metrics.precision
    metrics.recall
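
    Beyond the overall precision and recall, MulticlassMetrics also exposes per-class values and the confusion matrix; a short sketch using the same metrics object:

    // Rows of the confusion matrix are actual classes, columns are predicted classes
    println(metrics.confusionMatrix)
    
    metrics.labels.foreach { l =>
      println(s"class $l: precision = ${metrics.precision(l)}, recall = ${metrics.recall(l)}")
    }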
    