How to compute cumulative sum using Spark

有刺的猬 2020-12-01 08:51

I have an RDD of (String, Int) pairs which is sorted by key:

val data = Array(("c1", 6), ("c2", 3), ("c3", 4))
val rdd = sc.parallelize(data).sortByKey()
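
The goal, presumably, is the running total of the values in key order (the expected output is not stated explicitly in the question):

    // assumed desired output: cumulative sum of the values, keeping key order
    // Array(("c1", 6), ("c2", 9), ("c3", 13))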
         


        
5 Answers
  •  愿得一人  2020-12-01 09:04

    1. Compute partial results for each partition:

      val partials = rdd.mapPartitionsWithIndex((i, iter) => {
        val (keys, values) = iter.toSeq.unzip
        // Running sums within this partition; scanLeft(0) prepends a leading zero
        val sums = values.scanLeft(0)(_ + _)
        // Emit the per-partition cumulative pairs plus this partition's total
        Iterator((keys.zip(sums.tail), sums.last))
      })
      
    2. Collect the partial sums:

      val partialSums = partials.values.collect
      
    3. Compute cumulative sum over partitions and broadcast it:

      // Offset for partition i = sum of the totals of partitions 0 until i
      val sumMap = sc.broadcast(
        (0 until rdd.partitions.size)
          .zip(partialSums.scanLeft(0)(_ + _))
          .toMap
      )
      
    4. Compute final results:

      val result = partials.keys.mapPartitionsWithIndex((i, iter) => {
        // Shift every value in this partition by the totals of all preceding partitions
        val offset = sumMap.value(i)
        if (iter.isEmpty) Iterator()
        else iter.next.map { case (k, v) => (k, v + offset) }.toIterator
      })
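
    Putting it together, a quick sanity check on the sample data (a sketch: the
    exact split produced by sc.parallelize may differ; the comments assume the
    three pairs land in two partitions as shown):

      // Assumed split: partition 0 = [("c1",6)], partition 1 = [("c2",3), ("c3",4)]
      //   partialSums = Array(6, 7)
      //   sumMap      = Map(0 -> 0, 1 -> 6)
      result.collect().foreach(println)
      // (c1,6)
      // (c2,9)
      // (c3,13)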
      
