Cumulative product in Spark?

回眸只為那壹抹淺笑 submitted on 2019-11-28 13:13:11

As long as the numbers are strictly positive, as in your example (0 can be handled as well, if present, using coalesce), the simplest solution is to compute the sum of logarithms and take the exponential:

// Outside spark-shell, also import spark.implicits._ for toDF and the $"..." column syntax
import org.apache.spark.sql.functions.{exp, log, max, round, sum}

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3), 
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5), 
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"), 
    exp(sum(log($"val"))).as("val"))

Because this uses floating-point arithmetic, the result won't be exact:

result.show
+---+---+--------+------------------+
|  A|  B|    date|               val|
+---+---+--------+------------------+
| hh| jj|20171105|104.99999999999997|
| rr| gg|20171105|47.999999999999986|
+---+---+--------+------------------+

but after rounding it should be good enough for the majority of applications:

result.withColumn("val", round($"val")).show
+---+---+--------+-----+
|  A|  B|    date|  val|
+---+---+--------+-----+
| hh| jj|20171105|105.0|
| rr| gg|20171105| 48.0|
+---+---+--------+-----+
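
The logarithm trick only works directly for strictly positive inputs. Below is a minimal plain-Scala sketch (no Spark involved) of the arithmetic needed when zeros or negative values can occur. The object name `SignedLogProduct` and its method are hypothetical, and this is one possible reading of the remark about handling 0, not the original author's code. The idea is to track three things separately: whether any value is zero, how many values are negative, and the sum of the logarithms of the absolute values.

```scala
object SignedLogProduct {
  // Approximate product of xs computed via logarithms.
  // Hypothetical helper for illustration only.
  def product(xs: Seq[Long]): Double = {
    if (xs.contains(0L)) {
      // Any zero makes the whole product zero; log(0) must never be evaluated.
      0.0
    } else {
      // Sum logs of absolute values, then restore the sign from the parity
      // of the number of negative inputs.
      val magnitude = math.exp(xs.map(x => math.log(math.abs(x.toDouble))).sum)
      val negatives = xs.count(_ < 0)
      if (negatives % 2 == 0) magnitude else -magnitude
    }
  }
}
```

In Spark the same three pieces could be expressed as conditional aggregates over each group. The result is still a floating-point approximation, so the rounding step applies here as well.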

If that's not enough, you can define a UserDefinedAggregateFunction or an Aggregator (see "How to define and use a User-Defined Aggregate Function in Spark SQL?"), or use the functional API with reduceGroups:

import scala.math.Ordering

case class Record(A: String, B: String, date: String, value: Long)

df.withColumnRenamed("val", "value").as[Record]
  .groupByKey(x => (x.A, x.B))
  .reduceGroups((x, y) => x.copy(
    date = Ordering[String].max(x.date, y.date),
    value = x.value * y.value))
  .toDF("key", "value")
  .select($"value.*")
  .show
+---+---+--------+-----+
|  A|  B|    date|value|
+---+---+--------+-----+
| hh| jj|20171105|  105|
| rr| gg|20171105|   48|
+---+---+--------+-----+
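
For the Aggregator route mentioned above, a minimal sketch might look like the following. The names `LongProduct`, `LongProductUsage`, and `productByGroup` are hypothetical. The wrapper assumes Spark 3.0 or later, where `functions.udaf` is available, and a DataFrame with the columns `A`, `B`, and `val` from the example. Overflow is not checked, so very large products wrap around silently.

```scala
import org.apache.spark.sql.{DataFrame, Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Exact integer product as a typed Aggregator (hypothetical name).
object LongProduct extends Aggregator[Long, Long, Long] {
  // Neutral element of multiplication: the product of no values is 1.
  def zero: Long = 1L
  // Fold one input value into the partial product of a partition.
  def reduce(acc: Long, x: Long): Long = acc * x
  // Combine partial products computed on different partitions.
  def merge(a: Long, b: Long): Long = a * b
  // The buffer already holds the final result.
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object LongProductUsage {
  // Assumes Spark 3.0+; wraps the Aggregator for use on an untyped DataFrame.
  def productByGroup(df: DataFrame): DataFrame =
    df.groupBy("A", "B")
      .agg(udaf(LongProduct).apply(col("val").cast("long")).as("val"))
}
```

Unlike the logarithm approach, this stays in integer arithmetic, so no rounding is needed as long as the product fits in a Long.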

You can solve this using either collect_list with a UDF or a UDAF. A UDAF may be more efficient, but it is harder to implement because of the local aggregation it involves.

If you have a DataFrame like this:

+---+---+
|key|val|
+---+---+
|  a|  1|
|  a|  2|
|  a|  3|
|  b|  4|
|  b|  5|
+---+---+

You can invoke a UDF:

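import org.apache.spark.sql.functions.{collect_list, udf}

// Multiply all collected values of a group together
val prod = udf((vals: Seq[Int]) => vals.reduce(_ * _))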

df
  .groupBy($"key")
  .agg(prod(collect_list($"val")).as("val"))
  .show()

+---+---+
|key|val|
+---+---+
|  b| 20|
|  a|  6|
+---+---+
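
Two edge cases in the UDF body are worth a hedged note. `reduce` throws on an empty sequence, which can happen when every value in a group is null because collect_list skips nulls, and an Int product can overflow. A minimal plain-Scala sketch of a more defensive body follows; the name `SafeProduct` is hypothetical, and returning 1 for an empty group is an assumption about the desired behavior.

```scala
object SafeProduct {
  // Fold from 1L so that an empty input yields 1 (the empty product)
  // and the accumulation happens in Long rather than Int.
  def prod(vals: Seq[Int]): Long = vals.foldLeft(1L)(_ * _)
}
```

It could be wrapped with `udf(SafeProduct.prod _)` and used in place of the UDF above; the result column then has a long type instead of an integer type.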