How to transform a categorical variable in Spark into a set of columns coded as {0,1}?

终归单人心 2020-12-29 14:42

I'm trying to perform a logistic regression (LogisticRegressionWithLBFGS) with Spark MLlib (in Scala) on a dataset that contains categorical variables. I discover Spark

4 answers
  •  梦谈多话
    2020-12-29 14:57

    If the categories can fit in the driver memory, here is my suggestion:

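    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql._
    // Assumes spark-shell, where the implicits for toDF and $"..." are already in scope.
    // In an application, import spark.implicits._ (or sqlContext.implicits._ on Spark 1.x).

    val df = Seq((0, "a"),(1, "b"),(2, "c"),(3, "a"),(4, "a"),(5, "c"),(6,"c"),(7,"d"),(8,"b"))
                .toDF("id", "category")
    // Map each distinct category to a numeric index.
    val indexer = new StringIndexer()
                       .setInputCol("category")
                       .setOutputCol("categoryIndex")
                       .fit(df)

    val indexed = indexer.transform(df)

    // Collect the (category -> index) pairs to the driver; this is why they must fit in driver memory.
    // Going through .rdd keeps collectAsMap available on Spark 2.x, where DataFrame.map returns a Dataset.
    val categoriesIndices = indexed.select("category","categoryIndex").distinct
    val categoriesMap: scala.collection.Map[String,Double] =
      categoriesIndices.rdd.map(x => (x.getString(0), x.getDouble(1))).collectAsMap()

    // Build a UDF that returns 1 when the row's category has the expected index, 0 otherwise.
    def getCategoryIndex(catMap: scala.collection.Map[String,Double], expectedValue: Double) = udf((columnValue: String) =>
      if (catMap(columnValue) == expectedValue) 1 else 0)

    // Add one 0/1 column per category, named after the category value.
    val newDf: DataFrame = categoriesMap.keySet.toSeq.foldLeft[DataFrame](indexed)(
         (acc, c) =>
              acc.withColumn(c, getCategoryIndex(categoriesMap, categoriesMap(c))($"category"))
         )

    newDf.show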
    
    
    +---+--------+-------------+---+---+---+---+
    | id|category|categoryIndex|  b|  d|  a|  c|
    +---+--------+-------------+---+---+---+---+
    |  0|       a|          0.0|  0|  0|  1|  0|
    |  1|       b|          2.0|  1|  0|  0|  0|
    |  2|       c|          1.0|  0|  0|  0|  1|
    |  3|       a|          0.0|  0|  0|  1|  0|
    |  4|       a|          0.0|  0|  0|  1|  0|
    |  5|       c|          1.0|  0|  0|  0|  1|
    |  6|       c|          1.0|  0|  0|  0|  1|
    |  7|       d|          3.0|  0|  1|  0|  0|
    |  8|       b|          2.0|  1|  0|  0|  0|
    +---+--------+-------------+---+---+---+---+
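
To see what the foldLeft is doing without a cluster, here is a minimal sketch of the same one-hot logic on plain Scala collections. It does not use Spark at all; `OneHotSketch`, `Row`, `oneHot`, and the sample rows are hypothetical names made up for illustration.

```scala
// Plain-Scala sketch of the one-hot logic above (no Spark involved).
// OneHotSketch, Row and oneHot are hypothetical names used only for illustration.
object OneHotSketch {
  type Row = Map[String, Any]

  // Add one 0/1 entry per distinct category, mirroring the foldLeft over withColumn.
  def oneHot(rows: Seq[Row], column: String): Seq[Row] = {
    // Distinct category values, sorted so the result is deterministic.
    val categories = rows.map(_(column).toString).distinct.sorted
    rows.map { row =>
      categories.foldLeft(row) { (acc, c) =>
        // 1 if this row belongs to category c, 0 otherwise.
        acc + (c -> (if (row(column).toString == c) 1 else 0))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val rows: Seq[Row] = Seq(
      Map("id" -> 0, "category" -> "a"),
      Map("id" -> 1, "category" -> "b"),
      Map("id" -> 2, "category" -> "c")
    )
    oneHot(rows, "category").foreach(println)
  }
}
```

As in the Spark version, each new column is named after its category value, so a category called "id" or "category" would overwrite the existing column. Prefixing the generated names avoids that.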
    
