Spark: migrate SQL window function to RDD for better performance

盖世英雄少女心 2021-01-01 04:54

A function should be executed for multiple columns in a DataFrame:

def handleBias(df: DataFrame, colName: String, target: String = target) = {
    val w1 = Window.partitionBy(colName) // presumably a per-column window; the original snippet breaks off at "val w1 = W"
    // ...
}

2 Answers
  •  南方客 2021-01-01 05:40

    The main point here is to avoid unnecessary shuffles. Right now your code shuffles twice for each column you want to include and the resulting data layout cannot be reused between columns.

    For simplicity I assume that target is always binary ({0, 1}) and all remaining columns you use are of StringType. Furthermore I assume that the cardinality of the columns is low enough for the results to be grouped and handled locally. You can adjust these methods to handle other cases but it requires more work.
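
    For concreteness, the snippets below can be run against a minimal DataFrame matching these assumptions. The data and the columnsToDrop / columnsToCode values here are my own; only the names mirror the question:

      import org.apache.spark.sql.SparkSession

      val spark = SparkSession.builder.getOrCreate() // in spark-shell, `spark` already exists
      import spark.implicits._

      // Two low-cardinality string columns and a binary target.
      val df = Seq(
        ("A", "X", 1), ("A", "Y", 0), ("A", "X", 1),
        ("B", "Y", 0), ("B", "X", 1)
      ).toDF("col1", "col2", "TARGET")

      val columnsToDrop = Seq("col1")
      val columnsToCode = Seq("col2")
      val target = "TARGET"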

    RDD API

    • Reshape data from wide to long:

      import org.apache.spark.sql.functions._
      import spark.implicits._ // for $-syntax and encoders
      
      // One struct per input column, exploded into (column name, value) rows.
      val exploded = explode(array(
        (columnsToDrop ++ columnsToCode).map(c =>
          struct(lit(c).alias("k"), col(c).alias("v"))): _*
      )).alias("level")
      
      val long = df.select(exploded, $"TARGET")
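
      Each row of long now holds one (column name, value) struct plus the target; under the setup above the schema is roughly:

      long.printSchema()
      // root
      //  |-- level: struct
      //  |    |-- k: string
      //  |    |-- v: string
      //  |-- TARGET: integer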
      
    • aggregateByKey, reshape and collect:

      import org.apache.spark.util.StatCounter
      
      val lookup = long.as[((String, String), Int)].rdd
        // You can use a prefix partitioner (one that depends only on _._1)
        // to avoid reshuffling for groupByKey; see the sketch after this block.
        .aggregateByKey(StatCounter())(_ merge _, _ merge _) // fold each TARGET value into per-(column, level) statistics
        .map { case ((c, v), s) => (c, (v, s)) }             // re-key by column name alone
        .groupByKey
        .mapValues(_.toMap)                                  // per column: level -> StatCounter
        .collectAsMap                                        // safe to collect given the low-cardinality assumption
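
      The prefix partitioner mentioned in the comment is not spelled out in the original; a minimal sketch (the class name and partition count are my own) could look like this:

      import org.apache.spark.Partitioner

      // Partitions by the first tuple element only, so rows keyed by
      // (column, level) and rows later keyed by column alone land on the
      // same partition, letting groupByKey reuse the existing layout.
      class PrefixPartitioner(override val numPartitions: Int) extends Partitioner {
        def getPartition(key: Any): Int = {
          val prefix = key match {
            case (k, _) => k
            case k      => k
          }
          val h = prefix.hashCode % numPartitions
          if (h < 0) h + numPartitions else h // keep the result non-negative
        }
      }

      To benefit from it you would pass an instance to aggregateByKey, re-key with mapPartitions(..., preservesPartitioning = true) instead of map, and hand the same instance to groupByKey, so Spark sees matching partitioning and skips the second shuffle.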
      
    • You can use lookup to get statistics for individual columns and levels. For example:

      lookup("col1")("A")
      
      org.apache.spark.util.StatCounter = 
        (count: 3, mean: 0.666667, stdev: 0.471405, max: 1.000000, min: 0.000000)
      

      This gives you the statistics for col1, level A. Under the binary TARGET assumption this information is complete: the mean of a {0, 1} column is the fraction of 1s, so the counts and fractions of both classes follow from count and mean alone.

      You can use lookup like this to generate SQL expressions, or pass it to a udf and apply it to individual columns, as sketched below.
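
      A minimal sketch of the udf route; the means and encodeMean helpers are my own names, and the low-cardinality assumption keeps the shipped map small:

      import org.apache.spark.sql.functions.udf

      // For a {0, 1} target, the mean is the fraction of positives per level.
      val means: Map[String, Map[String, Double]] =
        lookup.map { case (c, m) =>
          (c, m.map { case (v, s) => (v, s.mean) })
        }.toMap

      // Hypothetical helper: replace each level of a column with its
      // observed positive fraction (0.0 for unseen levels).
      def encodeMean(colName: String) = udf { v: String =>
        means.getOrElse(colName, Map.empty[String, Double]).getOrElse(v, 0.0)
      }

      val encoded = df.withColumn("pre_col1", encodeMean("col1")($"col1"))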

    DataFrame API

    • Convert the data to long format, as for the RDD API.
    • Compute aggregates based on levels:

      val stats = long
        .groupBy($"level.k", $"level.v")
        .agg(mean($"TARGET"), sum($"TARGET"))
      
    • Depending on your preferences you can reshape this to enable efficient joins, or collect it to the driver and proceed as in the RDD solution; see the sketch below.
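
    For instance, a sketch of the join variant; the renames assume Spark's default aggregate column names ("avg(TARGET)", "sum(TARGET)"), and the aliases are mine:

      val statsNamed = stats
        .withColumnRenamed("avg(TARGET)", "target_mean")
        .withColumnRenamed("sum(TARGET)", "target_sum")

      // Flatten the struct so the per-level statistics can be joined back on.
      val joined = long
        .select($"level.k".alias("k"), $"level.v".alias("v"), $"TARGET")
        .join(statsNamed, Seq("k", "v"))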
