A function should be executed for multiple columns in a data frame
def handleBias(df: DataFrame, colName: String, target: String = target) = {
val w1 = W
The main point here is to avoid unnecessary shuffles. Right now your code shuffles twice for each column you want to include and the resulting data layout cannot be reused between columns.
For simplicity I assume that target is always binary ({0, 1}) and all remaining columns you use are of StringType. Furthermore I assume that the cardinality of the columns is low enough for the results to be grouped and handled locally. You can adjust these methods to handle other cases but it requires more work.
RDD API
Reshape data from wide to long:
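import org.apache.spark.sql.functions._
// Provides $"..." and .as[...]; assumes a SparkSession named spark is in scope
import spark.implicits._
// One (column name, value) struct per input column, exploded into one row each
val exploded = explode(array(
  (columnsToDrop ++ columnsToCode).map(c =>
    struct(lit(c).alias("k"), col(c).alias("v"))): _*
)).alias("level")
val long = df.select(exploded, $"TARGET")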
aggregateByKey, reshape and collect:
import org.apache.spark.util.StatCounter
val lookup = long.as[((String, String), Int)].rdd
  // You can use prefix partitioner (one that depends only on _._1)
  // to avoid reshuffling for groupByKey
  .aggregateByKey(StatCounter())(_ merge _, _ merge _)
  .map { case ((c, v), s) => (c, (v, s)) }
  .groupByKey
  .mapValues(_.toMap)
  .collectAsMap
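The comment in the code above mentions a prefix partitioner without showing one. A minimal sketch follows; the class name PrefixPartitioner is hypothetical, and it assumes keys are either (column, level) pairs or the bare column name, so that both the aggregation and the later grouping land in the same partition:

```scala
import org.apache.spark.Partitioner

// Hypothetical partitioner: routes a (column, level) key by the column name only,
// and routes a bare column name to that same partition.
class PrefixPartitioner(val numPartitions: Int) extends Partitioner {
  require(numPartitions > 0, "numPartitions must be positive")

  // Non-negative modulus of the hash code; null goes to partition 0
  private def bucket(prefix: Any): Int =
    if (prefix == null) 0
    else {
      val raw = prefix.hashCode % numPartitions
      if (raw < 0) raw + numPartitions else raw
    }

  def getPartition(key: Any): Int = key match {
    case (prefix, _) => bucket(prefix) // composite key: use the first element only
    case other       => bucket(other)  // plain key: hash it directly
  }

  // Spark compares partitioners with equals to decide whether a shuffle is needed
  override def equals(other: Any): Boolean = other match {
    case p: PrefixPartitioner => p.numPartitions == numPartitions
    case _                    => false
  }

  override def hashCode: Int = numPartitions
}
```

To benefit from it you would presumably pass one instance to aggregateByKey(StatCounter(), partitioner), re-key with mapPartitions(..., preservesPartitioning = true) instead of map (plain map drops the partitioner), and then call groupByKey(partitioner) with the same instance. Treat this as an untested outline rather than a drop-in replacement.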
You can use lookup to get statistics for individual columns and levels. For example:
lookup("col1")("A")
org.apache.spark.util.StatCounter =
(count: 3, mean: 0.666667, stdev: 0.471405, max: 1.000000, min: 0.000000)
This gives you the statistics for col1, level A. Based on the binary TARGET assumption this information is complete (you get count / fractions for both classes).
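To make the "complete" claim concrete: with a {0, 1} target the sum of the values equals the number of ones, so count and mean are enough to recover both class counts. A small sketch in plain Scala, where ClassCounts and classCounts are hypothetical names and the inputs are the count and mean you would read off a StatCounter:

```scala
// Hypothetical container for the per-class counts of one (column, level) pair
final case class ClassCounts(positives: Long, negatives: Long)

// Assumes the target is strictly binary ({0, 1}), as stated above
def classCounts(count: Long, mean: Double): ClassCounts = {
  // count * mean is the sum of the target, i.e. the number of ones;
  // rounding removes floating-point error in the mean
  val positives = math.round(count * mean)
  ClassCounts(positives, count - positives)
}
```

For the example above, classCounts(3, 0.666667) yields 2 positives and 1 negative.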
You can use lookup like this to generate SQL expressions or pass it to udf and apply it on individual columns.
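One way to generate such a SQL expression is to turn the per-level statistics of a column into a CASE expression string. The sketch below is plain Scala; caseExpr is a hypothetical helper, and it deliberately rejects levels containing quotes or backslashes instead of guessing at the escaping rules:

```scala
// Hypothetical helper: maps each level of a column to a statistic (here, the mean).
// Levels not present in the map evaluate to NULL because there is no ELSE branch.
def caseExpr(column: String, levelMeans: Map[String, Double]): String = {
  require(levelMeans.nonEmpty, "at least one level is required")
  val branches = levelMeans.toSeq.sortBy(_._1).map { case (level, mean) =>
    // Assumption: levels are simple strings; anything needing escaping is refused
    require(!level.exists(ch => ch == '\'' || ch == '\\'),
      s"unsupported character in level: $level")
    s"WHEN `$column` = '$level' THEN $mean"
  }
  branches.mkString("CASE ", " ", " END")
}
```

The input map could be built from lookup("col1").map { case (level, s) => level -> s.mean }, and the resulting string could then be passed to org.apache.spark.sql.functions.expr and used in a select on df.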
DataFrame API
Compute aggregates based on levels:
val stats = long
  .groupBy($"level.k", $"level.v")
  .agg(mean($"TARGET"), sum($"TARGET"))
Depending on your preferences, you can reshape this to enable efficient joins, or convert it to a local collection and proceed as in the RDD solution.