We have a dataframe like the one below, and the goal is a cumulative sum of value per Flag that resets once the running total reaches 20:

+------+-------+
| Flag | value |
+------+-------+
| 1    | 5     |
| 1    | 4     |
| 1    | 3     |
| 1    | 5     |
| 1    | 6     |
| 1    | 4     |
| 1    | 7     |
| 1    | 5     |
| 1    | 2     |
| 1    | 3     |
| 1    | 2     |
| 1    | 6     |
| 1    | 9     |
+------+-------+
The only way to do this without a UDF is to use higher-order functions:
from pyspark.sql import Row
from pyspark.sql.window import Window
import pyspark.sql.functions as f

df = spark.createDataFrame(
    [Row(Flag=1, value=5), Row(Flag=1, value=4), Row(Flag=1, value=3), Row(Flag=1, value=5), Row(Flag=1, value=6),
     Row(Flag=1, value=4), Row(Flag=1, value=7), Row(Flag=1, value=5), Row(Flag=1, value=2), Row(Flag=1, value=3),
     Row(Flag=1, value=2), Row(Flag=1, value=6), Row(Flag=1, value=9)]
)

window = Window.partitionBy('flag')

# Row number within each Flag partition. Ordering by 'flag' (a constant inside the
# partition) is not deterministic; with real data use an explicit ordering column
# such as a timestamp or monotonically_increasing_id.
df = df.withColumn('row_id', f.row_number().over(window.orderBy('flag')))

# collect_list already returns an array column, so no cast is needed
df = df.withColumn('values', f.collect_list('value').over(window))

# slice keeps the first row_id elements, i.e. the running prefix for this row
expr = "TRANSFORM(slice(values, 1, row_id), sliced_array -> sliced_array)"
df = df.withColumn('sliced_array', f.expr(expr))

# Fold the prefix: keep adding while the accumulator is below 20, otherwise restart
# from the current element (REDUCE is an alias for AGGREGATE, which also works here)
expr = "REDUCE(sliced_array, 0, (c, n) -> IF(c < 20, c + n, n))"
df = df.select('flag', 'value', f.expr(expr).alias('cumsum'))

df.show()
Output:
+----+-----+------+
|flag|value|cumsum|
+----+-----+------+
| 1| 5| 5|
| 1| 4| 9|
| 1| 3| 12|
| 1| 5| 17|
| 1| 6| 23|
| 1| 4| 4|
| 1| 7| 11|
| 1| 5| 16|
| 1| 2| 18|
| 1| 3| 21|
| 1| 2| 2|
| 1| 6| 8|
| 1| 9| 17|
+----+-----+------+
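
For reference, here is a minimal pure-Python sketch of the same fold (the threshold 20 and the reset rule come straight from the REDUCE expression above). It reproduces the cumsum column and is handy for sanity-checking the logic outside Spark:

values = [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]

running, cumsum = 0, []
for v in values:
    # keep adding while the previous total is below 20, otherwise restart from the current value
    running = running + v if running < 20 else v
    cumsum.append(running)

print(cumsum)
# [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]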