Retrieve top n in each group of a DataFrame in pyspark

Submitted by Anonymous (unverified) on 2019-12-03 07:50:05

Question:

There's a DataFrame in pyspark with data as below:

user_id  object_id  score
user_1   object_1   3
user_1   object_1   1
user_1   object_2   2
user_2   object_1   5
user_2   object_2   2
user_2   object_2   6

What I want is to return the 2 records with the highest score in each group of rows sharing the same user_id. Consequently, the result should look like the following:

user_id  object_id  score
user_1   object_1   3
user_1   object_2   2
user_2   object_2   6
user_2   object_1   5

I'm really new to pyspark. Could anyone give me a code snippet or a pointer to the relevant documentation for this problem? Many thanks!

Answer 1:

I believe you need to use window functions to attain the rank of each row based on user_id and score, and subsequently filter your results to only keep the first two values.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
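If you don't want the helper rank column in the final result, you can drop it after filtering. A minimal sketch under the same setup (the variable name top2 is just for illustration):

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

# Same window as above: rank rows within each user_id by descending score.
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

# Keep the top 2 rows per user, then drop the helper column.
top2 = (df.select('*', rank().over(window).alias('rank'))
          .filter(col('rank') <= 2)
          .drop('rank'))
top2.show()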

In general, the official programming guide is a good place to start learning Spark.

Data

rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
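The snippet above uses the older SQLContext entry point. On Spark 2.x and later the same DataFrame can be built directly from a SparkSession without going through an RDD; a minimal sketch (the app name is arbitrary, and spark is often already provided by the pyspark shell):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("top-n-per-group").getOrCreate()

df = spark.createDataFrame(
    [("user_1", "object_1", 3),
     ("user_1", "object_2", 2),
     ("user_2", "object_1", 5),
     ("user_2", "object_2", 2),
     ("user_2", "object_2", 6)],
    ["user_id", "object_id", "score"])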


Answer 2:

For an exact top-n, row_number is more reliable than rank, because rank assigns the same value to tied scores, so a filter on rank can return more than n rows per group:

from pyspark.sql.functions import row_number, col

n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20).toPandas()

Note the limit(20).toPandas() trick used instead of show(): it gives nicer formatting in Jupyter notebooks.
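To see the difference concretely, here is a small sketch with a tie at the cutoff (user_3 and its scores are made up for the example, and spark is assumed to be an existing SparkSession): with n = 2, a filter on rank keeps three rows because the tied rows share rank 2, while a filter on row_number keeps exactly two.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, row_number, col

# Hypothetical group with a tie at the cutoff position (two scores of 7).
tied = spark.createDataFrame(
    [("user_3", "object_1", 9),
     ("user_3", "object_2", 7),
     ("user_3", "object_3", 7),
     ("user_3", "object_4", 4)],
    ["user_id", "object_id", "score"])

w = Window.partitionBy("user_id").orderBy(col("score").desc())

# rank() yields 1, 2, 2, 4 -> rank <= 2 keeps 3 rows.
# row_number() yields 1, 2, 3, 4 -> row_number <= 2 keeps exactly 2 rows
# (which tied row survives is arbitrary unless orderBy gets a tie-breaker).
tied.select("*",
            rank().over(w).alias("rank"),
            row_number().over(w).alias("row_number")).show()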


