Spark, ML, StringIndexer: handling unseen labels

Submitted by ♀尐吖头ヾ on 2019-11-27 05:38:07

Question


My goal is to build a multiclass classifier.

I have built a pipeline for feature extraction, and its first step is a StringIndexer transformer that maps each class name to a label; this label is used in the classifier training step.

The pipeline is fitted on the training set.

The test set has to be processed by the fitted pipeline in order to extract the same feature vectors.

My test set files have the same structure as the training set. A possible scenario here is encountering an unseen class name in the test set; in that case the StringIndexer will fail to find a label, and an exception will be raised.

Is there a solution for this case, or how can we prevent it from happening?


Answer 1:


With Spark 2.2 (released July 2017) you can use the .setHandleInvalid("keep") option when creating the indexer. With this option, the indexer adds a new index when it sees a new label. Note that in earlier versions you also have the "skip" option, which makes the indexer ignore (remove) rows with unseen labels.

val categoryIndexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .setHandleInvalid("keep") // options are "keep", "error" or "skip"



Answer 2:


There's a way around this in Spark 1.6.

Here's the jira: https://issues.apache.org/jira/browse/SPARK-8764

Here's an example:

val categoryIndexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .setHandleInvalid("skip") // new method.  values are "error" or "skip"

I started using this, but ended up going back to KrisP's 2nd bullet point about fitting this particular Estimator to the full dataset.

You'll need this later in the pipeline when you convert indices back to strings with IndexToString.

Here's the modified example:

val categoryIndexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .fit(itemsDF) // Fit the Estimator and create a Model (Transformer)

... do some kind of classification ...

val categoryReverseIndexer = new IndexToString()
  .setInputCol(classifier.getPredictionCol)
  .setOutputCol("predictedCategory")
  .setLabels(categoryIndexerModel.labels) // Use the labels from the Model
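The reverse mapping IndexToString performs with that labels array is simple to illustrate in plain Scala (the object name and the "__unknown__" fallback below are hypothetical, only for illustration):

```scala
// Sketch of what IndexToString does with the labels array taken
// from the fitted StringIndexerModel: map a predicted index back
// to the original category string.
object ReverseIndexSketch {
  val labels = Array("cat", "dog", "fish") // e.g. categoryIndexerModel.labels

  def predictedCategory(predictedIndex: Double): String = {
    val i = predictedIndex.toInt
    if (i >= 0 && i < labels.length) labels(i)
    else "__unknown__" // e.g. the extra "keep" bucket has no original label
  }
}
```

This is why the fitted model's labels must be carried through the pipeline: without them, an index cannot be turned back into a category name.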



Answer 3:


No nice way to do it, I'm afraid. Either

  • filter out the test examples with unknown labels before applying StringIndexer
  • or fit StringIndexer to the union of train and test dataframe, so you are assured all labels are there
  • or map test examples with unknown labels to a known label

Here is some sample code to perform the above operations:

// get training labels from the original training dataframe
val trainlabels = traindf.select(colname).distinct.map(_.getString(0)).collect // Array[String]
// or get labels from a trained StringIndexer model
val trainlabels = simodel.labels

// define a UDF on your dataframe that will be used for filtering
val filterudf = udf { label: String => trainlabels.contains(label) }

// filter out the bad examples
val filteredTestdf = testdf.filter(filterudf(testdf(colname)))

// map an unknown value to some known value, say "a"
val mapudf = udf { label: String => if (trainlabels.contains(label)) label else "a" }

// add a new column to testdf:
val transformedTestdf = testdf.withColumn("newcol", mapudf(testdf(colname)))
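The second bullet (fitting on the union of train and test) has no snippet above; the idea is that the label set learned from the union necessarily covers every label the test set can contain. A plain-Scala sketch of that guarantee, with hypothetical column data (in Spark itself this would be roughly `new StringIndexer().fit(traindf.select(colname).union(testdf.select(colname)))`):

```scala
// Sketch of the "fit on the union" idea: labels learned from
// train ++ test cover every label that can appear in the test set.
object UnionFitSketch {
  // Stand-in for StringIndexer.fit: collect the distinct labels.
  def fitLabels(column: Seq[String]): Seq[String] = column.distinct

  val trainCol = Seq("cat", "dog", "cat") // hypothetical train column
  val testCol  = Seq("dog", "bird")       // "bird" is unseen in training

  val unionLabels = fitLabels(trainCol ++ testCol)
  val covered     = testCol.forall(unionLabels.contains)
}
```

Note this approach only works when the test data is available at fit time, which is not the case for a model serving previously unseen data.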



Answer 4:


In my case, I was running Spark ALS on a large dataset, and the data was not available in all partitions, so I had to cache() the data appropriately, and it worked like a charm.




Answer 5:


To me, ignoring the rows completely by setting an argument (https://issues.apache.org/jira/browse/SPARK-8764) is not really a feasible way to solve the issue.

I ended up creating my own CustomStringIndexer transformer, which assigns a new value to all new strings that were not encountered during training. You can also do this by changing the relevant portions of the Spark feature code (just remove the if condition that explicitly checks for this and make it return the length of the array instead) and recompiling the jar.

Not really an easy fix, but it certainly is a fix.

I remember seeing an issue in JIRA to incorporate this as well: https://issues.apache.org/jira/browse/SPARK-17498

It is set to be released with Spark 2.2 though. Just have to wait I guess :S



Source: https://stackoverflow.com/questions/34681534/spark-ml-stringindexer-handling-unseen-labels
