Get Top 3 values for every key in an RDD in Spark


Question


I'm a beginner with Spark, and I am trying to create an RDD that contains the top 3 values for every key (not just the top 3 values overall). My current RDD contains thousands of entries in the following format:

(key, String, value)

So imagine I had an RDD with content like this:

[("K1", "aaa", 6), ("K1", "bbb", 3), ("K1", "ccc", 2), ("K1", "ddd", 9),
("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

I can currently display the overall top 3 values in the RDD like so:

("K1", "ddd", 9)
("B1", "iop", 8)
("B1", "rty", 7)

Using:

top3RDD = rdd.takeOrdered(3, key=lambda x: -x[2])  # negate so takeOrdered returns the largest values

What I want instead is to gather the top 3 values for every key in the RDD, so the result would be:

("K1", "ddd", 9)
("K1", "aaa", 6)
("K1", "bbb", 3)
("B1", "iop", 8)
("B1", "rty", 7)
("B1", "qwe", 4)

Answer 1:


You need to groupBy the key and then you can use heapq.nlargest to take the top 3 values from each group:

from heapq import nlargest

rdd.groupBy(
    lambda x: x[0]                                       # group records by the key field
).flatMap(
    lambda g: nlargest(3, g[1], key=lambda x: x[2])      # top 3 by value within each group
).collect()

[('B1', 'iop', 8), 
 ('B1', 'rty', 7), 
 ('B1', 'qwe', 4), 
 ('K1', 'ddd', 9), 
 ('K1', 'aaa', 6), 
 ('K1', 'bbb', 3)]
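
One caveat: groupBy shuffles and materializes every record of a group before nlargest trims it. As a sketch (not part of the original answer), aggregateByKey can keep at most 3 records per key per partition, which scales better when some keys have many records:

from heapq import heappush, heappushpop

TOP_N = 3

def add_record(heap, rec):
    # maintain a min-heap of at most TOP_N (value, record) pairs
    if len(heap) < TOP_N:
        heappush(heap, (rec[2], rec))
    else:
        heappushpop(heap, (rec[2], rec))
    return heap

def merge_heaps(h1, h2):
    # merge two partial heaps, again keeping at most TOP_N pairs
    for item in h2:
        if len(h1) < TOP_N:
            heappush(h1, item)
        else:
            heappushpop(h1, item)
    return h1

(rdd.map(lambda x: (x[0], x))                    # key each record by its first field
    .aggregateByKey([], add_record, merge_heaps)
    .flatMap(lambda kv: [rec for _, rec in sorted(kv[1], reverse=True)])
    .collect())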



Answer 2:


If you're open to converting your RDD to a DataFrame, you can define a Window that partitions by the key and sorts by the value descending. Use this Window to compute the row number, and keep the rows where the row number is less than or equal to 3.

import pyspark.sql.functions as f
from pyspark.sql import Window

# rank rows within each key by descending value
w = Window.partitionBy("key").orderBy(f.col("value").desc())

rdd.toDF(["key", "String", "value"])\
    .select("*", f.row_number().over(w).alias("rowNum"))\
    .where(f.col("rowNum") <= 3)\
    .drop("rowNum")\
    .show()
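
A side note (not from the original answer): row_number always keeps exactly 3 rows per key, breaking ties arbitrarily. If every row tied at the cutoff should be kept instead, rank() is a drop-in replacement:

rdd.toDF(["key", "String", "value"])\
    .withColumn("rnk", f.rank().over(w))\
    .where(f.col("rnk") <= 3)\
    .drop("rnk")\
    .show()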


Source: https://stackoverflow.com/questions/49713886/get-top-3-values-for-every-key-in-a-rdd-in-spark
