PySpark: Cumulative Sum with reset condition

Backend · Open · 2 answers · 2074 views

盖世英雄少女心 2021-02-09 04:49

We have a dataframe like the one below:

+------+--------------------+
| Flag |               value|
+------+--------------------+
|1     |5                   |
|1     |4                   |
|1     |3                   |
|1     |5                   |
|1     |6                   |
|1     |4                   |
|1     |7                   |
|1     |5                   |
|1     |2                   |
|1     |3                   |
|1     |2                   |
|1     |6                   |
|1     |9                   |
+------+--------------------+

We want a running sum of the value column, but it must reset: while the running total is below 20 we keep adding, and once it reaches 20 or more the sum restarts from the next row's value.
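
In plain Python, the reset logic we are after looks like this (an illustration of the requirement only, not a Spark solution):

    values = [5, 4, 3, 5, 6, 4, 7, 5, 2, 3, 2, 6, 9]
    out, total = [], 0
    for v in values:
        total = total + v if total < 20 else v  # restart once the total reached 20
        out.append(total)
    print(out)  # [5, 9, 12, 17, 23, 4, 11, 16, 18, 21, 2, 8, 17]
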
2 Answers
  •  不要未来只要你来
    2021-02-09 05:37

    The only way to do this without a UDF is to use higher-order functions:


    Spark >= 2.4.x

    from pyspark.sql import Row
    from pyspark.sql.window import Window
    import pyspark.sql.functions as f


    df = spark.createDataFrame(
        [Row(Flag=1, value=5), Row(Flag=1, value=4), Row(Flag=1, value=3), Row(Flag=1, value=5), Row(Flag=1, value=6),
         Row(Flag=1, value=4), Row(Flag=1, value=7), Row(Flag=1, value=5), Row(Flag=1, value=2), Row(Flag=1, value=3),
         Row(Flag=1, value=2), Row(Flag=1, value=6), Row(Flag=1, value=9)]
    )

    window = Window.partitionBy('flag')

    # Number the rows and collect, for every row, the list of all values in the
    # partition. This toy data has no real ordering column, so the order inside
    # the partition is whatever Spark produces; with real data, order the window
    # by a timestamp or id column instead of the constant 'flag'.
    df = df.withColumn('row_id', f.row_number().over(window.orderBy('flag')))
    df = df.withColumn('values', f.collect_list('value').over(window))

    # Keep only the values up to and including the current row.
    df = df.withColumn('sliced_array', f.expr("slice(values, 1, row_id)"))

    # Fold the slice left to right: keep adding while the running total is
    # below 20, otherwise restart the sum from the current element.
    # (AGGREGATE is the Spark 2.4 name; REDUCE only exists as an alias in
    # later versions.)
    expr = "AGGREGATE(sliced_array, 0, (c, n) -> IF(c < 20, c + n, n))"
    df = df.select('flag', 'value', f.expr(expr).alias('cumsum'))

    df.show()
    

    Output:

    +----+-----+------+
    |flag|value|cumsum|
    +----+-----+------+
    |   1|    5|     5|
    |   1|    4|     9|
    |   1|    3|    12|
    |   1|    5|    17|
    |   1|    6|    23|
    |   1|    4|     4|
    |   1|    7|    11|
    |   1|    5|    16|
    |   1|    2|    18|
    |   1|    3|    21|
    |   1|    2|     2|
    |   1|    6|     8|
    |   1|    9|    17|
    +----+-----+------+
    
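    For Spark >= 3.1, here is a minimal sketch of the same idea using the Python-side higher-order functions (f.aggregate accepts a lambda over Columns), so the SQL strings can be dropped. It assumes the input dataframe df built above, before the final select:

    window = Window.partitionBy('flag')
    result = (
        df
        .withColumn('row_id', f.row_number().over(window.orderBy('flag')))
        .withColumn('values', f.collect_list('value').over(window))
        .withColumn(
            'cumsum',
            f.aggregate(
                f.expr('slice(values, 1, row_id)'),              # values up to this row
                f.lit(0),                                        # start the fold at 0
                lambda c, n: f.when(c < 20, c + n).otherwise(n)  # reset once >= 20
            )
        )
        .select('flag', 'value', 'cumsum')
    )
    result.show()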
