Spark - Window with recursion? - Conditionally propagating values across rows

Anonymous (unverified), submitted on 2019-12-03 01:27:01

Question:

I have the following dataframe showing the revenue of purchases.

    +-------+--------+-------+
    |user_id|visit_id|revenue|
    +-------+--------+-------+
    |      1|       1|      0|
    |      1|       2|      0|
    |      1|       3|      0|
    |      1|       4|    100|
    |      1|       5|      0|
    |      1|       6|      0|
    |      1|       7|    200|
    |      1|       8|      0|
    |      1|       9|     10|
    +-------+--------+-------+

Ultimately, I want the new column purch_revenue to show, in every row, the revenue generated by the corresponding purchase. As a workaround, I also tried to introduce a purchase identifier purch_id that is incremented each time a purchase is made. It is listed here only for reference.

    +-------+--------+-------+-------------+--------+
    |user_id|visit_id|revenue|purch_revenue|purch_id|
    +-------+--------+-------+-------------+--------+
    |      1|       1|      0|          100|       1|
    |      1|       2|      0|          100|       1|
    |      1|       3|      0|          100|       1|
    |      1|       4|    100|          100|       1|
    |      1|       5|      0|          100|       2|
    |      1|       6|      0|          100|       2|
    |      1|       7|    200|          100|       2|
    |      1|       8|      0|          100|       3|
    |      1|       9|     10|          100|       3|
    +-------+--------+-------+-------------+--------+

I've tried using the lag/lead functions like this:

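    from pyspark.sql import functions as fn
    from pyspark.sql.window import Window

    user_timeline = Window.partitionBy("user_id").orderBy("visit_id")
    # Keep revenue on purchase rows; otherwise take the revenue of the next visit
    find_rev = fn.when(fn.col("revenue") > 0, fn.col("revenue")) \
        .otherwise(fn.lead(fn.col("revenue"), 1).over(user_timeline))
    df.withColumn("purch_revenue", find_rev)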

This copies revenue when revenue > 0 and otherwise pulls the revenue up from the next row, so each purchase reaches back only one row. I could chain this for a fixed N, but that isn't a general solution.
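The limitation is easier to see in a plain-Python model of the attempt. This is a sketch, not Spark code: the function name `lead_fill` is hypothetical, and it works on a single user's revenue values already ordered by visit_id. Chaining the when/lead rule n times amounts to looking at most n rows ahead for the first non-zero revenue; for simplicity the model returns 0 when nothing is found within reach, where Spark's lead would yield 0 or null.

```python
from typing import List


def lead_fill(revenues: List[int], n: int) -> List[int]:
    """Model of chaining when(revenue > 0, ...).otherwise(lead(...)) n times.

    Each row takes the first non-zero revenue among itself and the next n rows;
    0 stands in for "nothing found" (Spark would give 0 or null there).
    """
    filled = []
    for i in range(len(revenues)):
        reachable = revenues[i:i + n + 1]  # current row plus up to n rows ahead
        filled.append(next((r for r in reachable if r > 0), 0))
    return filled


revenue = [0, 0, 0, 100, 0, 0, 200, 0, 10]  # user_id 1, ordered by visit_id
print(lead_fill(revenue, 1))  # [0, 0, 100, 100, 0, 200, 200, 10, 10]
print(lead_fill(revenue, 3))  # [100, 100, 100, 100, 200, 200, 200, 10, 10]
```

With a single step, visits 1, 2 and 5 are still 0. n has to be at least the longest run of zero-revenue visits before a purchase (3 in this data), and that bound is not known in advance in general.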

  • Is there a way to apply this recursively until revenue > 0?
  • Alternatively, is there a way to increment a value based on a condition? I tried to find one but couldn't.
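Conceptually, the "recursion" in the first bullet is a backward scan that carries the nearest following purchase up to every earlier row, with no limit on how far. The plain-Python sketch below models that for a single user's rows ordered by visit_id; the function name `backfill_next_purchase` is hypothetical, and it assumes each row should get the revenue of the first purchase at or after it (rows after the last purchase get None).

```python
from typing import List, Optional


def backfill_next_purchase(revenues: List[int]) -> List[Optional[int]]:
    """Give each row the revenue of the first purchase at or after it.

    Rows after the user's last purchase get None, since no purchase follows.
    """
    result: List[Optional[int]] = [None] * len(revenues)
    carry: Optional[int] = None
    # Walk from the last visit back to the first, remembering the nearest
    # purchase at or after the current row.
    for i in range(len(revenues) - 1, -1, -1):
        if revenues[i] > 0:
            carry = revenues[i]
        result[i] = carry
    return result


revenue = [0, 0, 0, 100, 0, 0, 200, 0, 10]  # user_id 1, ordered by visit_id
print(backfill_next_purchase(revenue))  # [100, 100, 100, 100, 200, 200, 200, 10, 10]
```

Unlike a fixed chain of lead calls, this loop carries a value arbitrarily far back.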

Answer 1:

Window functions don't support recursion, but recursion isn't required here. This type of sessionization can be handled easily with a cumulative sum:

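    from pyspark.sql.functions import col, sum, when, lag  # note: shadows built-in sum
    from pyspark.sql.window import Window

    w = Window.partitionBy("user_id").orderBy("visit_id")

    # Flag purchase rows with 1, shift the flag one row down with lag (default 0),
    # then take a running sum over w; +1 makes the ids start at 1.
    purch_id = sum(lag(when(
        col("revenue") > 0, 1).otherwise(0),
        1, 0
    ).over(w)).over(w) + 1

    df.withColumn("purch_id", purch_id).show()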

    +-------+--------+-------+--------+
    |user_id|visit_id|revenue|purch_id|
    +-------+--------+-------+--------+
    |      1|       1|      0|       1|
    |      1|       2|      0|       1|
    |      1|       3|      0|       1|
    |      1|       4|    100|       1|
    |      1|       5|      0|       2|
    |      1|       6|      0|       2|
    |      1|       7|    200|       2|
    |      1|       8|      0|       3|
    |      1|       9|     10|       3|
    +-------+--------+-------+--------+
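To see why the cumulative sum works, and to take the remaining step to purch_revenue, the sketch below models both computations in plain Python for one user's rows ordered by visit_id. It is an illustration, not Spark code, and the function names are hypothetical. The lag moves each purchase flag to the following row, so the running sum increments right after every purchase while the purchase row itself stays in the group it closes. For purch_revenue, one option (assuming revenue is never negative) is the maximum revenue within each group; in PySpark that could be written as `max("revenue").over(Window.partitionBy("user_id", "purch_id"))`, with `max` imported from `pyspark.sql.functions`, applied after purch_id has been added as a column.

```python
from itertools import accumulate
from typing import Dict, List


def purchase_ids(revenues: List[int]) -> List[int]:
    """Model of sum(lag(when(revenue > 0, 1).otherwise(0), 1, 0)) + 1."""
    flags = [1 if r > 0 else 0 for r in revenues]  # 1 marks a purchase row
    # lag(flag, 1, 0): shift the flags down one row, first row defaults to 0,
    # so a purchase row still belongs to the group it closes
    lagged = ([0] + flags[:-1]) if flags else []
    # running sum in visit order, +1 so that ids start at 1
    return [total + 1 for total in accumulate(lagged)]


def purchase_revenue(revenues: List[int], purch_ids: List[int]) -> List[int]:
    """Model of max(revenue) over a window partitioned by purch_id."""
    best: Dict[int, int] = {}
    for pid, rev in zip(purch_ids, revenues):
        best[pid] = max(best.get(pid, rev), rev)
    # spread each group's maximum back onto every row of that group
    return [best[pid] for pid in purch_ids]


revenue = [0, 0, 0, 100, 0, 0, 200, 0, 10]  # user_id 1, ordered by visit_id
ids = purchase_ids(revenue)
print(ids)                             # [1, 1, 1, 1, 2, 2, 2, 3, 3]
print(purchase_revenue(revenue, ids))  # [100, 100, 100, 100, 200, 200, 200, 10, 10]
```

Rows after a user's last purchase form a final group whose maximum is 0, so they get 0 rather than a purchase value.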

