Spark-SQL Window functions on Dataframe - Finding first timestamp in a group


Question


I have below dataframe (say UserData).

uid region  timestamp
a   1   1
a   1   2
a   1   3
a   1   4
a   2   5
a   2   6
a   2   7
a   3   8
a   4   9
a   4   10
a   4   11
a   4   12
a   1   13
a   1   14
a   3   15
a   3   16
a   5   17
a   5   18
a   5   19
a   5   20

This data represents a user (uid) travelling across different regions (region) at different times (timestamp). For simplicity, timestamp is shown as an 'int'. Note that the actual dataframe will not necessarily be in increasing order of timestamp, and there may be rows from other users interleaved in between. For simplicity, I have shown the dataframe for a single user only, in monotonically increasing order of timestamp.
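
For reference, the sample dataframe above can be built with a minimal sketch like the following (assuming a SparkSession named spark is in scope, with its implicits imported for toDF):

import spark.implicits._

val UserData = Seq(
  ("a", 1, 1), ("a", 1, 2), ("a", 1, 3), ("a", 1, 4),
  ("a", 2, 5), ("a", 2, 6), ("a", 2, 7),
  ("a", 3, 8),
  ("a", 4, 9), ("a", 4, 10), ("a", 4, 11), ("a", 4, 12),
  ("a", 1, 13), ("a", 1, 14),
  ("a", 3, 15), ("a", 3, 16),
  ("a", 5, 17), ("a", 5, 18), ("a", 5, 19), ("a", 5, 20)
).toDF("uid", "region", "timestamp")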

My goal is to find how much time user 'a' spent in each region, and in what order. So my final expected output looks like:

uid region  regionTimeStart regionTimeEnd
a   1   1   5
a   2   5   8
a   3   8   9
a   4   9   13
a   1   13  15
a   3   15  17
a   5   17  20

Based on my findings, Spark SQL window functions can be used for this purpose. I have tried the following:

val w = Window
  .partitionBy("region")
  .partitionBy("uid")
  .orderBy("timestamp")

val resultDF = UserData.select(
  UserData("uid"), UserData("timestamp"),
  UserData("region"), rank().over(w).as("Rank"))

But from here onwards, I am not sure how to get the regionTimeStart and regionTimeEnd columns. The regionTimeEnd column is nothing but the 'lead' of regionTimeStart, except for the last entry in the group.

I see that aggregate operations have 'first' and 'last' functions, but those require grouping the data by ('uid', 'region'), which destroys the order of the path traversed: at times 13 and 14 the user has come back to region '1', and I want that to be retained as a separate visit instead of clubbing it with the initial visit to region '1' at time 1.
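
For example, a naive aggregation like the sketch below (using min/max as regionTimeStart/regionTimeEnd) would collapse both visits to region '1' into a single row spanning timestamps 1 to 14, which is exactly what I want to avoid:

import org.apache.spark.sql.functions.{min, max}

UserData
  .groupBy("uid", "region")
  .agg(min("timestamp").as("regionTimeStart"), max("timestamp").as("regionTimeEnd"))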

It would be very helpful if anyone could guide me. I am new to Spark, and I have a better understanding of the Scala Spark APIs than the Python/Java ones.


Answer 1:


Window functions are indeed useful here, although your approach can work only if you assume that a user visits a given region only once. Also, the window definition you use is incorrect: multiple calls to partitionBy simply return new objects with different window definitions, so only the last call takes effect. If you want to partition by multiple columns you should pass them in a single call (.partitionBy("region", "uid")), as shown below.
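
For illustration, the corrected multi-column window definition would look like this (it is not used in the solution below, which partitions by uid only):

val wByRegionAndUid = Window
  .partitionBy("region", "uid")
  .orderBy("timestamp")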

Let's start by marking continuous visits to each region:

import org.apache.spark.sql.functions.{lag, sum, not}
import org.apache.spark.sql.expressions.Window
import spark.implicits._ // for the $"..." column syntax (assumes a SparkSession named spark)

// order each user's rows by time
val w = Window.partitionBy($"uid").orderBy($"timestamp")

// 1 whenever the region differs from the previous row (null-safe, so the first row counts as a change)
val change = (not(lag($"region", 1).over(w) <=> $"region")).cast("int")

// running sum of changes gives each contiguous visit its own index
val ind = sum(change).over(w)

val dfWithInd = df.withColumn("ind", ind)
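
On the sample data, each contiguous run of a region gets its own ind value, so dfWithInd should look roughly like this:

// dfWithInd.orderBy($"timestamp").show()
// +---+------+---------+---+
// |uid|region|timestamp|ind|
// +---+------+---------+---+
// |  a|     1|        1|  1|
// |  a|     1|        2|  1|
// |  a|     1|        3|  1|
// |  a|     1|        4|  1|
// |  a|     2|        5|  2|
// |  a|     2|        6|  2|
// |  a|     2|        7|  2|
// |  a|     3|        8|  3|
// ...
// |  a|     5|       20|  7|
// +---+------+---------+---+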

Next we simply aggregate over these groups and find the leads:

import org.apache.spark.sql.functions.{lead, coalesce, min, max}

// the end of a visit is the start of the next one, or the visit's own last timestamp for the final group
val regionTimeEnd = coalesce(lead($"timestamp", 1).over(w), $"max_")

val result = dfWithInd
  .groupBy($"uid", $"region", $"ind")
  .agg(min($"timestamp").alias("timestamp"), max($"timestamp").alias("max_"))
  .drop("ind")
  .withColumn("regionTimeEnd", regionTimeEnd)
  .withColumnRenamed("timestamp", "regionTimeStart")
  .drop("max_")

result.show

// +---+------+---------------+-------------+
// |uid|region|regionTimeStart|regionTimeEnd|
// +---+------+---------------+-------------+
// |  a|     1|              1|            5|
// |  a|     2|              5|            8|
// |  a|     3|              8|            9|
// |  a|     4|              9|           13|
// |  a|     1|             13|           15|
// |  a|     3|             15|           17|
// |  a|     5|             17|           20|
// +---+------+---------------+-------------+


Source: https://stackoverflow.com/questions/35315840/spark-sql-window-functions-on-dataframe-finding-first-timestamp-in-a-group
