Question
I'm trying to filter an RDD of tuples to return the largest N tuples based on key values. I need the return format to be an RDD.
So the RDD:
[(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]
filtered for the largest 3 keys should return the RDD:
[(6,'p'), (12,'e'), (49,'y')]
Doing a sortByKey() and then take(N) returns the values and doesn't result in an RDD, so that won't work.
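For example, with the example data above in a variable rdd, something like this just hands back a plain Python list on the driver rather than an RDD:
# take() is an action, so the result is a local list, not a distributed RDD.
rdd.sortByKey(ascending=False).take(3)
# [(49, 'y'), (12, 'e'), (6, 'p')]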
I could return all of the keys, sort them, find the Nth largest value, and then filter the RDD for key values greater than that, but that seems very inefficient.
What would be the best way to do this?
Answer 1:
With RDDs
A quick but not particularly efficient solution is to follow sortByKey() with zipWithIndex() and filter():
n = 3
rdd = sc.parallelize([(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')])
# Sort descending so the first n indices correspond to the largest keys.
rdd.sortByKey(ascending=False).zipWithIndex().filter(lambda xi: xi[1] < n).keys()
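To see why the filter on xi[1] works, the intermediate zipWithIndex() step pairs each record with its position in the sorted order (a quick check on the example data, not part of the original answer):
# Each element becomes ((key, value), index); the filter keeps indices 0..n-1.
rdd.sortByKey(ascending=False).zipWithIndex().collect()
# [((49, 'y'), 0), ((12, 'e'), 1), ((6, 'p'), 2), ((4, 'a'), 3), ((2, 'u'), 4)]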
If n is relatively small compared to the RDD size, a slightly more efficient approach is to avoid a full sort:
import heapq
def key(kv):
    return kv[0]
# Keep only the n largest candidates from each partition before the final sort.
top_per_partition = rdd.mapPartitions(lambda iter: heapq.nlargest(n, iter, key))
top_per_partition.sortByKey(ascending=False).zipWithIndex().filter(lambda xi: xi[1] < n).keys()
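The point of the mapPartitions step is that each partition contributes at most n candidates, so the final sortByKey() only shuffles at most n * numPartitions records instead of the full data set. One way to inspect those candidates (a sketch, not part of the original answer; glom() groups each partition's contents into a list):
# List the surviving candidates of every partition.
top_per_partition.glom().collect()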
If the keys are much smaller than the values and the order of the final output doesn't matter, then the filter approach can work just fine:
keys = rdd.keys()
identity = lambda x: x
# Find the n-th largest key: sort the per-partition candidates descending,
# keep the first n, and use the smallest of those as the cut-off.
offset = (keys
    .mapPartitions(lambda iter: heapq.nlargest(n, iter))
    .sortBy(identity, ascending=False)
    .zipWithIndex()
    .filter(lambda xi: xi[1] < n)
    .keys()
    .min())
rdd.filter(lambda kv: kv[0] >= offset)
Also, it won't keep exactly n values in case of ties.
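If exactly n records are needed despite ties, one possible follow-up (a sketch, not part of the original answer) is to reuse the zipWithIndex() trick on the filtered result:
# Trim the (possibly larger than n) filtered result down to exactly n records.
largest = rdd.filter(lambda kv: kv[0] >= offset)
largest.sortByKey(ascending=False).zipWithIndex().filter(lambda xi: xi[1] < n).keys()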
With DataFrames
You can just orderBy() and limit():
from pyspark.sql.functions import col
rdd.toDF().orderBy(col("_1").desc()).limit(n)
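Since the question asks for an RDD, the DataFrame result can be converted back if needed (a sketch, not part of the original answer; Row is a tuple subclass, so map(tuple) restores plain (key, value) pairs):
# .rdd exposes the DataFrame as an RDD of Row objects; map(tuple) yields plain tuples.
rdd.toDF().orderBy(col("_1").desc()).limit(n).rdd.map(tuple)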
Answer 2:
A lower-effort approach, since you only want to convert the take(N) results to a new RDD:
sc.parallelize(yourSortedRdd.take(Nth))
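For this to return the largest N, yourSortedRdd has to be sorted in descending key order first; spelled out with the names from the first answer (a sketch, not part of the original answer):
# take(n) on a descending sort picks the largest keys; parallelize turns the list back into an RDD.
largest_n = sc.parallelize(rdd.sortByKey(ascending=False).take(n))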
Source: https://stackoverflow.com/questions/34292879/return-rdd-of-largest-n-values-from-another-rdd-in-spark