How to compute cumulative sum using Spark

有刺的猬 2020-12-01 08:51

I have an RDD of (String, Int) pairs that is sorted by key:

val data = Array(("c1",6), ("c2",3), ("c3",4))
val rdd = sc.parallelize(data).sortByKey()
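
How do I compute the cumulative sum of the values in key order? For the data above, the expected result would be ("c1",6), ("c2",9), ("c3",13).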
5 Answers
  •  感情败类    2020-12-01 09:00

    Here is a solution in PySpark. Internally it is essentially the same as @zero323's Scala solution, but it provides a general-purpose function with a Spark-like API: compute a cumulative sum within each partition, collect each partition's total, broadcast the running offsets, and add the appropriate offset to every row.

    import numpy as np
    def cumsum(rdd, get_summand):
        """Given an ordered RDD of items, computes the cumulative sum of
        get_summand(row), where row is an item in the RDD. Returns an RDD
        of (cumulative_sum, row) pairs.
        """
        # Step 1: cumulative sum within each partition.
        def cumsum_in_partition(iter_rows):
            total = 0
            for row in iter_rows:
                total += get_summand(row)
                yield (total, row)
        rdd = rdd.mapPartitions(cumsum_in_partition)
    
        # Step 2: the last per-partition cumulative sum is that partition's
        # total (assumes non-empty partitions; an empty one would yield None).
        def last_partition_value(iter_rows):
            final = None
            for cumsum, row in iter_rows:
                final = cumsum
            return (final,)

        partition_sums = rdd.mapPartitions(last_partition_value).collect()
        # Offsets: partition i must add the combined total of partitions
        # 0..i-1, so prepend 0 to the running totals and broadcast them
        # (sc is the active SparkContext).
        partition_cumsums = list(np.cumsum(partition_sums))
        partition_cumsums = [0] + partition_cumsums
        partition_cumsums = sc.broadcast(partition_cumsums)
    
        # Step 3: add each partition's offset to its local cumulative sums.
        def add_sums_of_previous_partitions(idx, iter_rows):
            return ((cumsum + partition_cumsums.value[idx], row)
                    for cumsum, row in iter_rows)
        rdd = rdd.mapPartitionsWithIndex(add_sums_of_previous_partitions)
        return rdd
    
    # Test for correctness by summing numbers, with and without Spark.
    rdd = sc.range(10000, numSlices=10).sortBy(lambda x: x)
    cumsums, values = zip(*cumsum(rdd, lambda x: x).collect())
    assert all(cumsums == np.cumsum(values))
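
    As a further illustration (not part of the original answer), here is a minimal sketch of how one might apply cumsum to the asker's (String, Int) RDD, assuming an active SparkContext named sc as above:

    # Hypothetical usage for the (key, value) data from the question.
    data = [("c1", 6), ("c2", 3), ("c3", 4)]
    rdd = sc.parallelize(data).sortByKey()
    result = cumsum(rdd, lambda kv: kv[1]).collect()
    # Expected: [(6, ("c1", 6)), (9, ("c2", 3)), (13, ("c3", 4))]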
    
