Explain the aggregate functionality in Spark

误落风尘  2020-12-07 12:23

I am looking for a better explanation of the aggregate functionality that is available via Spark in Python.

The example I have is as follows (using pyspark from …

9 Answers

  •  无人及你  2020-12-07 12:52

    You can use the following code (in Scala) to see precisely what aggregate is doing. It builds a tree of all the addition and merge operations:

    // A binary tree that records the structure of the computation
    sealed trait Tree[+A]
    case class Leaf[A](value: A) extends Tree[A]                         // a single value
    case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]  // one combination step

    // The zero value becomes a Leaf; seqOp and combOp wrap each step in a Branch
    val zero: Tree[Int] = Leaf(0)
    val rdd = sc.parallelize(1 to 4).repartition(3)
    

    And then, in the shell:

    scala> rdd.glom().collect()
    res5: Array[Array[Int]] = Array(Array(4), Array(1, 2), Array(3))
    

    So, we have these 3 partitions: [4], [1,2], and [3].

    scala> rdd.aggregate(zero)((l,r)=>Branch(l, Leaf(r)), (l,r)=>Branch(l,r))
    res11: Tree[Int] = Branch(Branch(Branch(Leaf(0),Branch(Leaf(0),Leaf(4))),Branch(Leaf(0),Leaf(3))),Branch(Branch(Leaf(0),Leaf(1)),Leaf(2)))
    

    You can represent the result as a tree:

    +
    | \__________________
    +                    +
    | \________          | \
    +          +         +   2
    | \        | \       | \         
    0  +       0  3      0  1
       | \
       0  4
    

    You can see that the first zero element is created on the driver node (at the left of the tree), and then the results for all the partitions are merged one by one. You can also see that if you replace 0 by 1, as you did in your question, it adds 1 to each per-partition result and also adds 1 to the initial value on the driver. So the total number of times the zero value you give is used is:

    number of partitions + 1.
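
    With the Tree example you can verify this directly by counting the zero leaves in the result above (countZeros is a hypothetical helper, not part of the original answer):

    // Count how many times the zero value (Leaf(0)) occurs in the result tree
    def countZeros(t: Tree[Int]): Int = t match {
      case Leaf(0)      => 1   // one use of the zero value
      case Leaf(_)      => 0   // a real element from the RDD
      case Branch(l, r) => countZeros(l) + countZeros(r)
    }

    countZeros(res11)  // 4: once per partition, plus once on the driver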

    So, in your case, the result of

    aggregate(
      (X, Y),
      (lambda acc, value: (acc[0] + value, acc[1] + 1)),
      (lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])))
    

    will be:

    (sum(elements) + (num_partitions + 1)*X, count(elements) + (num_partitions + 1)*Y)
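
    For instance, running the same sum-and-count aggregation in Scala on the 3-partition RDD above with the example zero value (1, 0) gives concrete numbers (a sketch; (1, 0) stands in for the (X, Y) of the question):

    // zero = (1, 0): the sum component picks up 1 per zero use, the count component 0
    val (sum, count) = rdd.aggregate((1, 0))(
      (acc, value) => (acc._1 + value, acc._2 + 1),          // seqOp: fold in one element
      (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2) // combOp: merge partition results
    )
    // sum   = 10 + (3 + 1) * 1 = 14
    // count =  4 + (3 + 1) * 0 =  4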
    

    The implementation of aggregate is quite simple. It is defined in RDD.scala, line 1107:

    def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
      // Clone the zero value since we will also be serializing it as part of tasks
      var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
      val cleanSeqOp = sc.clean(seqOp)
      val cleanCombOp = sc.clean(combOp)
      val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
      val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
      sc.runJob(this, aggregatePartition, mergeResult)
      jobResult
    }
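
    To see how aggregatePartition and mergeResult interact without a cluster, here is a minimal local sketch of the same two-phase logic, using the partition layout from the glom() output above (plain Scala collections, not Spark):

    // Assumed partition contents, taken from the glom() output: [4], [1, 2], [3]
    val parts = Seq(Seq(4), Seq(1, 2), Seq(3))
    val zeroValue = 1

    // Phase 1 (aggregatePartition): each partition folds from its own copy of zeroValue
    val taskResults = parts.map(_.foldLeft(zeroValue)(_ + _))  // Seq(5, 4, 4)

    // Phase 2 (mergeResult): the driver folds the task results, starting from zeroValue again
    val jobResult = taskResults.foldLeft(zeroValue)(_ + _)     // 14

    // zeroValue was folded in parts.size + 1 = 4 times in total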
    
