I have an RDD of (String, Int) pairs which is sorted by key:

val data = Array(("c1", 6), ("c2", 3), ("c3", 4))
val rdd = sc.parallelize(data).sortByKey()

How can I compute a cumulative sum of the values?
Here is a solution in PySpark. Internally it's essentially the same as @zero323's Scala solution, but it provides a general-purpose function with a Spark-like API.
import numpy as np

def cumsum(rdd, get_summand):
    """Given an ordered rdd of items, computes cumulative sum of
    get_summand(row), where row is an item in the RDD.
    """
    # 1. Within each partition, pair every row with its running
    #    partition-local total.
    def cumsum_in_partition(iter_rows):
        total = 0
        for row in iter_rows:
            total += get_summand(row)
            yield (total, row)
    rdd = rdd.mapPartitions(cumsum_in_partition)

    # 2. Collect the last (largest) partial sum of each partition,
    #    i.e. that partition's total.
    def last_partition_value(iter_rows):
        final = 0  # 0 rather than None, so empty partitions don't break np.cumsum below
        for cumsum, row in iter_rows:
            final = cumsum
        return (final,)
    partition_sums = rdd.mapPartitions(last_partition_value).collect()

    # 3. partition_cumsums[i] is the total of partitions 0..i-1: the
    #    offset to add to every partial sum in partition i.
    partition_cumsums = [0] + list(np.cumsum(partition_sums))
    # use the RDD's own SparkContext so the function doesn't rely on a global sc
    partition_cumsums = rdd.context.broadcast(partition_cumsums)

    def add_sums_of_previous_partitions(idx, iter_rows):
        return ((cumsum + partition_cumsums.value[idx], row)
                for cumsum, row in iter_rows)
    rdd = rdd.mapPartitionsWithIndex(add_sums_of_previous_partitions)
    return rdd
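
The heavy lifting is in steps 1 and 3: each partition computes its local running sums independently, and a small driver-side pass turns the per-partition totals into offsets. A minimal plain-Python sketch of that offset bookkeeping, with hypothetical partition contents (reusing the np import from above):

partitions = [[1, 2], [3, 4], [5]]  # hypothetical partition contents
local = [list(np.cumsum(p)) for p in partitions]         # [[1, 3], [3, 7], [5]]
offsets = [0] + list(np.cumsum([p[-1] for p in local]))  # [0, 3, 10, 15]
print([int(c + offsets[i]) for i, p in enumerate(local) for c in p])
# [1, 3, 6, 10, 15], i.e. the cumulative sum of [1, 2, 3, 4, 5]

As in the function, the last offset (15, the grand total) is computed but never used.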
# test for correctness by summing numbers, with and without Spark
rdd = sc.range(10000, numSlices=10).sortBy(lambda x: x)
cumsums, values = zip(*cumsum(rdd, lambda x: x).collect())
assert all(cumsums == np.cumsum(values))
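
Applied to the (String, Int) RDD from the question (assuming the running sum is over the Int values, extracted from each pair by the lambda kv: kv[1] argument):

data = [("c1", 6), ("c2", 3), ("c3", 4)]
rdd = sc.parallelize(data).sortByKey()
print(cumsum(rdd, lambda kv: kv[1]).collect())
# [(6, ('c1', 6)), (9, ('c2', 3)), (13, ('c3', 4))]

Each row comes back paired with the cumulative total up to and including it.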