Take top N elements from each group in PySpark RDD (without using groupByKey)


Question


I have an RDD like the following

dataSource = sc.parallelize( [("user1", (3, "blue")), ("user1", (4, "black")), ("user2", (5, "white")), ("user2", (3, "black")), ("user2", (6, "red")), ("user1", (1, "red"))] )

I want to use reduceByKey to find Top 2 colors for each user so the output would be an RDD like:

sc.parallelize([("user1", ["black", "blue"]), ("user2", ["red", "white"])])

So I need to reduce by key, sort each key's (number, color) values by number, and return the top n colors.

I don't want to use groupByKey. If there is anything better than reduceByKey (other than groupByKey), that would be great :)
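
For reference, here is one minimal sketch of the reduceByKey route described above (this is not from the answer below; it assumes the dataSource RDD defined earlier, and the variable name top2 is just chosen for illustration). It keeps each per-key list bounded at 2 elements, so no full group is ever materialized:

import heapq

top2 = (dataSource
    .mapValues(lambda x: [x])                            # wrap each (number, color) pair in a 1-element list
    .reduceByKey(lambda a, b: heapq.nlargest(2, a + b))  # merge lists, keeping only the 2 largest pairs
    .mapValues(lambda pairs: [color for (num, color) in sorted(pairs, reverse=True)]))

top2.collect()
# e.g. [('user1', ['black', 'blue']), ('user2', ['red', 'white'])]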


Answer 1:


You can, for example, use a heap queue. Required imports:

import heapq
from functools import partial

Helper functions:

def zero_value(n):
    """Initialize a fixed-size min-heap filled with n placeholder entries.
    An empty string is used as the placeholder payload so the tuples stay
    comparable in Python 3 (None cannot be ordered against strings).
    If n is large it could be more efficient to track the number of
    elements on the heap, i.e. (cnt, heap), and switch between heappush
    and heappushpop once we exceed n. I leave this as an exercise for
    the reader."""
    return [(float("-inf"), "") for _ in range(n)]

def seq_func(acc, x):
    # Push the new (number, color) pair and pop the smallest entry,
    # so the heap always holds the n largest pairs seen so far.
    heapq.heappushpop(acc, x)
    return acc

def merge_func(acc1, acc2, n):
    # Combine two accumulators and keep only the n largest pairs.
    return heapq.nlargest(n, heapq.merge(acc1, acc2))

def finalize(kvs):
    # Drop the placeholders and return only the colors.
    return [v for (k, v) in kvs if k != float("-inf")]
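
Before wiring these into Spark, here is a quick local illustration of how the helpers behave for n = 2 (plain Python, no Spark needed; this walkthrough is not part of the original answer):

acc = zero_value(2)                            # [(-inf, ''), (-inf, '')]
acc = seq_func(acc, (3, "blue"))               # evicts one placeholder
acc = seq_func(acc, (4, "black"))              # evicts the other placeholder
acc = seq_func(acc, (1, "red"))                # too small: heap stays [(3, 'blue'), (4, 'black')]

other = seq_func(zero_value(2), (5, "white"))  # a second accumulator, e.g. from another partition

merged = merge_func(acc, other, n=2)           # [(5, 'white'), (4, 'black')]
print(finalize(merged))                        # ['white', 'black']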

Data:

rdd = sc.parallelize([
    ("user1", (3, "blue")), ("user1", (4, "black")),
    ("user2", (5, "white")), ("user2", (3, "black")),
    ("user2", (6, "red")), ("user1", (1, "red"))])

Solution:

(rdd
    .aggregateByKey(zero_value(2), seq_func, partial(merge_func, n=2))
    .mapValues(finalize)
    .collect())

Result:

[('user2', ['red', 'white']), ('user1', ['black', 'blue'])]
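
For completeness, here is a rough sketch of the variant hinted at in zero_value's docstring: track how many elements are on the heap so that no placeholder entries are needed. The names (zero_value2, seq_func2, and so on) are made up for this sketch and are not from the original answer:

def zero_value2():
    # (number of elements on the heap, the heap itself)
    return (0, [])

def seq_func2(acc, x, n):
    cnt, heap = acc
    if cnt < n:
        heapq.heappush(heap, x)        # heap not full yet: just push
        return (cnt + 1, heap)
    heapq.heappushpop(heap, x)         # heap full: push and evict the smallest
    return (cnt, heap)

def merge_func2(acc1, acc2, n):
    merged = heapq.nlargest(n, acc1[1] + acc2[1])
    return (len(merged), merged)

def finalize2(acc):
    return [color for (num, color) in sorted(acc[1], reverse=True)]

(rdd
    .aggregateByKey(zero_value2(), partial(seq_func2, n=2), partial(merge_func2, n=2))
    .mapValues(finalize2)
    .collect())
# same result as above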


Source: https://stackoverflow.com/questions/41903310/take-top-n-elements-from-each-group-in-pyspark-rdd-without-using-groupbykey
