Aggregate (Sum) over Window for a list of Columns

Submitted by 。_饼干妹妹 on 2019-12-23 02:44:13

Question


I'm having trouble finding a generic way to calculate the Sum (or any aggregate function) over a given window, for a list of columns available in the DataFrame.

import spark.implicits._  // needed for .toDF on an RDD of tuples

val inputDF = spark
  .sparkContext
  .parallelize(
    Seq(
      (1, 2, 1, 30, 100),
      (1, 2, 2, 30, 100),
      (1, 2, 3, 30, 100),
      (11, 21, 1, 30, 100),
      (11, 21, 2, 30, 100),
      (11, 21, 3, 30, 100)
    ),
    10)
  .toDF("c1", "c2", "offset", "v1", "v2")

inputDF.show
+---+---+------+---+---+
| c1| c2|offset| v1| v2|
+---+---+------+---+---+
|  1|  2|     1| 30|100|
|  1|  2|     2| 30|100|
|  1|  2|     3| 30|100|
| 11| 21|     1| 30|100|
| 11| 21|     2| 30|100|
| 11| 21|     3| 30|100|
+---+---+------+---+---+

Given a DataFrame like the one above, it's easy to compute the sum for a list of columns with a plain groupBy, as in the snippet below:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val groupKey = List("c1", "c2").map(x => col(x.trim))
val orderByKey = List("offset").map(x => col(x.trim))

// One sum(...) per value column, keeping the original column name
val aggKey = List("v1", "v2").map(c => sum(c).alias(c.trim))

// Window spec, used by the windowed version further down
val w = Window.partitionBy(groupKey: _*).orderBy(orderByKey: _*)

val outputDF = inputDF
  .groupBy(groupKey: _*)
  .agg(aggKey.head, aggKey.tail: _*)

outputDF.show

But I can't find a similar approach for aggregate functions over a window spec. So far I've only been able to solve this by specifying each column individually, as shown below:

val outputDF2 = inputDF
    .withColumn("cumulative_v1", sum(when($"offset".between(-1, 1), inputDF("v1")).otherwise(0)).over(w))
    .withColumn("cumulative_v3", sum(when($"offset".between(-2, 2), inputDF("v1")).otherwise(0)).over(w))

I'd appreciate it if there is a way to do this aggregation over a dynamic list of columns. Thanks!
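A minimal sketch of one way to handle a dynamic list, assuming the same sample data and window as in the question: generate the windowed expressions with a `for` comprehension over the column names and offset bounds, then append them all in a single `select`. The names `valueCols`, `bounds` and the `<col>_within_<b>` column naming are hypothetical choices for illustration, and the sketch builds its own local `SparkSession` so it can run on its own.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, when}

// Local session for the sketch; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("dynamic-window-agg").getOrCreate()
import spark.implicits._

val inputDF = Seq(
  (1, 2, 1, 30, 100), (1, 2, 2, 30, 100), (1, 2, 3, 30, 100),
  (11, 21, 1, 30, 100), (11, 21, 2, 30, 100), (11, 21, 3, 30, 100)
).toDF("c1", "c2", "offset", "v1", "v2")

// Same window as in the question: partition by (c1, c2), order by offset.
val w = Window.partitionBy(col("c1"), col("c2")).orderBy(col("offset"))

val valueCols = Seq("v1", "v2") // hypothetical: columns to aggregate
val bounds = Seq(1, 2)          // hypothetical: each b means offset.between(-b, b)

// One conditional windowed sum per (column, bound) pair,
// mirroring sum(when(...).otherwise(0)).over(w) from the question.
val windowCols: Seq[Column] = for {
  c <- valueCols
  b <- bounds
} yield sum(when(col("offset").between(-b, b), col(c)).otherwise(0))
  .over(w)
  .alias(s"${c}_within_$b")

// Keep all existing columns and add the new ones in one projection.
val outputDF: DataFrame = inputDF.select((inputDF.columns.toSeq.map(col) ++ windowCols): _*)

outputDF.orderBy("c1", "offset").show()
```

Because every expression is an ordinary `Column`, `valueCols` and `bounds` could just as well be read from configuration at runtime; nothing in the query depends on a fixed number of columns.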


Answer 1:


I think I found an approach that works better than the one in the question above.

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

/**
    * Utility method that takes a DataFrame and a list of columns and appends a windowed aggregate for each of them
    * @param colsToAggregate    Seq[String] of the columns in inputDF to be aggregated
    * @param inputDF            Input DataFrame
    * @param f                  function called with each column name that returns the aggregate Column, e.g. c => sum(c)
    * @param partitionByColSeq  Seq[String] of column names to partition inputDF by before applying the aggregate
    * @param orderByColSeq      Seq[String] of column names to order each partition by before applying the aggregate
    * @param name_prefix        String to prefix the new columns with, to avoid collisions
    * @param name               maps an aggregated column name to the new column name; defaults to identity, reusing the original name
    * @return                   inputDF with one new column per entry in colsToAggregate
    */
  def withRollingAggregateColumns(colsToAggregate: Seq[String],
                                  inputDF: DataFrame,
                                  f: String => Column,
                                  partitionByColSeq: Seq[String],
                                  orderByColSeq: Seq[String],
                                  name_prefix: String,
                                  name: String => String = identity): DataFrame = {

    val groupByKey = partitionByColSeq.map(x => col(x.trim))
    val orderByKey = orderByColSeq.map(x => col(x.trim))

    val w = Window.partitionBy(groupByKey: _*).orderBy(orderByKey: _*)

    // Fold over the column names, adding one windowed aggregate column per step
    colsToAggregate
      .foldLeft(inputDF)(
        (df, elementInCols) => df
          .withColumn(
            name_prefix + "_" + name(elementInCols),
            f(elementInCols).over(w)
          )
      )
  }

Here the utility method takes a DataFrame as input and appends one new column per entry in colsToAggregate, computed by applying the provided function f over the window. It uses withColumn inside a foldLeft to iterate over the list of columns to be aggregated. To avoid column name collisions, it prefixes each new aggregate column with the user-provided name_prefix.
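A usage sketch of the helper, assuming the question's sample data and a local `SparkSession`. It repeats a condensed copy of `withRollingAggregateColumns` so the snippet compiles on its own, and calls it twice with different aggregate functions passed as `f`. Because the window has an `orderBy` but no explicit frame, Spark uses a frame running from the start of the partition to the current row, so `sum` yields a running total rather than a partition-wide total.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, max, sum}

val spark = SparkSession.builder().master("local[*]").appName("rolling-agg-usage").getOrCreate()
import spark.implicits._

// Condensed copy of withRollingAggregateColumns so this snippet runs on its own.
def withRollingAggregateColumns(colsToAggregate: Seq[String],
                                inputDF: DataFrame,
                                f: String => Column,
                                partitionByColSeq: Seq[String],
                                orderByColSeq: Seq[String],
                                name_prefix: String,
                                name: String => String = identity): DataFrame = {
  val w = Window
    .partitionBy(partitionByColSeq.map(c => col(c.trim)): _*)
    .orderBy(orderByColSeq.map(c => col(c.trim)): _*)
  colsToAggregate.foldLeft(inputDF) { (df, c) =>
    df.withColumn(name_prefix + "_" + name(c), f(c).over(w))
  }
}

val inputDF = Seq(
  (1, 2, 1, 30, 100), (1, 2, 2, 30, 100), (1, 2, 3, 30, 100),
  (11, 21, 1, 30, 100), (11, 21, 2, 30, 100), (11, 21, 3, 30, 100)
).toDF("c1", "c2", "offset", "v1", "v2")

// Running sums: adds sum_v1 and sum_v2.
val sums = withRollingAggregateColumns(
  Seq("v1", "v2"), inputDF, c => sum(col(c)), Seq("c1", "c2"), Seq("offset"), "sum")

// Same helper, different aggregate: adds max_v1 and max_v2 on top of the sums.
val sumsAndMax = withRollingAggregateColumns(
  Seq("v1", "v2"), sums, c => max(col(c)), Seq("c1", "c2"), Seq("offset"), "max")

sumsAndMax.orderBy("c1", "offset").show()
```

The helper adds one `withColumn` call per column. The Spark API docs for `Dataset.withColumn` note that each call introduces a projection, so calling it in a loop over many columns can build large query plans; for long column lists, a single `select` (or `withColumns` on Spark 3.3 and later) is the usual alternative.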



Source: https://stackoverflow.com/questions/44513660/aggregate-sum-over-window-for-a-list-of-columns
