I have an RDD of (String, Int) pairs sorted by key, and I want a running (cumulative) sum of the values in key order:
// In spark-shell, where sc is the SparkContext
val data = Array(("c1", 6), ("c2", 3), ("c3", 4))
val rdd = sc.parallelize(data).sortByKey()
First, compute the running sums within each partition, together with each partition's total:
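val partials = rdd.mapPartitions(iter => {
  // Materialize the partition and split it into keys and values
  val (keys, values) = iter.toVector.unzip
  // scanLeft yields 0 followed by the running totals
  val sums = values.scanLeft(0)(_ + _)
  // One element per partition: (key, running sum) pairs plus the partition total
  Iterator((keys.zip(sums.tail), sums.last))
}).cache()  // used twice below, so avoid recomputing it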
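The sketch below shows what the per-partition step produces. It runs the same logic on plain Scala collections instead of an RDD. Each inner Seq stands in for one partition. The two-partition split is a hypothetical example, and Spark may split the data differently. The helper name partialResult is also hypothetical. The code can be pasted into the Scala REPL.

```scala
// Hypothetical stand-in for an RDD: each inner Seq plays the role of one partition.
val localPartitions: Seq[Seq[(String, Int)]] =
  Seq(Seq(("c1", 6)), Seq(("c2", 3), ("c3", 4)))

// Same logic as the per-partition body (hypothetical helper name):
// running sums inside one partition, plus that partition's total.
def partialResult(part: Seq[(String, Int)]): (Seq[(String, Int)], Int) = {
  val (keys, values) = part.unzip
  val sums = values.scanLeft(0)(_ + _) // 0 followed by the running totals
  (keys.zip(sums.tail), sums.last)     // an empty partition gives (empty, 0)
}

val partialsLocal = localPartitions.map(partialResult)
partialsLocal.foreach(println)
// (List((c1,6)),6)
// (List((c2,3), (c3,7)),7)
```

The running sums are still local to each partition: "c3" maps to 7, not 13. The partition totals (6 and 7) are what the next steps use to shift them.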
Collect the per-partition totals to the driver:
val partialSums = partials.values.collect()
Compute the cumulative sum over partitions, which is the offset each partition has to add, and broadcast it:
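// Offset of partition i = sum of the totals of partitions 0 until i.
// scanLeft yields n + 1 values; zip drops the final grand total.
val sumMap = sc.broadcast(
  (0 until rdd.partitions.size)
    .zip(partialSums.scanLeft(0)(_ + _))
    .toMap
)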
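The offset computation is an exclusive prefix sum over the partition totals. The sketch below isolates it on plain Scala values. It assumes the hypothetical totals 6 and 7, which is what two partitions holding (c1, 6) and (c2, 3), (c3, 4) would produce. The name localTotals is illustrative only.

```scala
// Hypothetical per-partition totals, as the collect step might return them.
val localTotals: Array[Int] = Array(6, 7)

// Exclusive prefix sum: partition i is shifted by the total of partitions 0 until i.
// scanLeft gives Array(0, 6, 13); zipping with 0 until 2 keeps only the first two.
val offsets: Map[Int, Int] =
  (0 until localTotals.length).zip(localTotals.scanLeft(0)(_ + _)).toMap

println(offsets) // Map(0 -> 0, 1 -> 6)
```

The keys are simply 0 until n, so an indexed sequence such as localTotals.scanLeft(0)(_ + _) would serve equally well. The Map just mirrors the broadcast value above.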
Finally, add each partition's offset to its running sums:
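val result = partials.keys.mapPartitionsWithIndex((i, iter) => {
  val offset = sumMap.value(i)
  // Each partition holds a single Seq of (key, running sum) pairs;
  // flatMap also copes with a partition that holds none
  iter.flatMap(_.map { case (k, v) => (k, v + offset) })
})
// result.collect() returns Array((c1,6), (c2,9), (c3,13))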
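Here is a minimal end-to-end sketch of the three passes (per-partition running sums, offsets, shift). It uses plain Scala collections, with each inner Seq standing in for one partition of a key-sorted RDD. It assumes the partitions are already in key order, which sortByKey guarantees on the Spark side. The function name cumulativeSum is hypothetical and not part of any Spark API.

```scala
// Local simulation of the whole pipeline on plain collections (no Spark).
// Each inner Seq is a hypothetical partition of a key-sorted RDD.
def cumulativeSum(partitions: Seq[Seq[(String, Int)]]): Seq[(String, Int)] = {
  // Step 1: running sums inside each partition plus each partition's total
  val partials = partitions.map { part =>
    val (keys, values) = part.unzip
    val sums = values.scanLeft(0)(_ + _)
    (keys.zip(sums.tail), sums.last)
  }
  // Step 2: exclusive prefix sum of the partition totals = offset per partition
  val offsets = partials.map(_._2).scanLeft(0)(_ + _)
  // Step 3: shift every running sum by its partition's offset
  partials.zip(offsets).flatMap { case ((pairs, _), offset) =>
    pairs.map { case (k, v) => (k, v + offset) }
  }
}

val localResult = cumulativeSum(Seq(Seq(("c1", 6)), Seq(("c2", 3), ("c3", 4))))
println(localResult) // List((c1,6), (c2,9), (c3,13))
```

Only the small vector of partition totals has to cross partition boundaries. That is why the Spark version can collect it to the driver and broadcast it cheaply, while the per-key data stays distributed.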