I am learning Spark, but I can't understand the combineByKey function.
>>> data = sc.parallelize([("A",1),("A",2),("B",1),("B
Here is an example of combineByKey. The objective is to find the per-key average of the input data.
scala> val kvData = Array(("a",1),("b",2),("a",3),("c",9),("b",6))
kvData: Array[(String, Int)] = Array((a,1), (b,2), (a,3), (c,9), (b,6))
scala> val kvDataDist = sc.parallelize(kvData,5)
kvDataDist: org.apache.spark.rdd.RDD[(String, Int)] = ParallelCollectionRDD[0] at parallelize at <console>:26
scala> val keyAverages = kvDataDist.combineByKey(x=>(x,1),(a: (Int,Int),x: Int)=>(a._1+x,a._2+1),(b: (Int,Int),c: (Int,Int))=>(b._1+c._1,b._2+c._2))
keyAverages: org.apache.spark.rdd.RDD[(String, (Int, Int))] = ShuffledRDD[4] at combineByKey at <console>:25
scala> keyAverages.collect
res0: Array[(String, (Int, Int))] = Array((c,(9,1)), (a,(4,2)), (b,(8,2)))
scala> val keyAveragesFinal = keyAverages.map(x => (x._1,x._2._1/x._2._2))
keyAveragesFinal: org.apache.spark.rdd.RDD[(String, Int)] = MapPartitionsRDD[3] at map at <console>:25
scala> keyAveragesFinal.collect
res1: Array[(String, Int)] = Array((c,9), (a,2), (b,4))
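The one-liner above packs all three functions into a single call, which is what usually makes it hard to read. The same computation may be easier to follow with the functions pulled out into vals named after combineByKey's parameter names (createCombiner, mergeValue, mergeCombiners). A rough sketch, reusing the kvDataDist RDD defined above:

// Same per-key average, with the three combineByKey arguments named.
val createCombiner = (v: Int) => (v, 1)                                           // first value of a key in a partition -> (sum, count)
val mergeValue = (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1)            // fold another value from the same partition into (sum, count)
val mergeCombiners = (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2) // merge (sum, count) pairs coming from different partitions

val averages = kvDataDist
  .combineByKey(createCombiner, mergeValue, mergeCombiners)
  .mapValues { case (sum, count) => sum.toDouble / count }

averages.collect   // e.g. Array((c,9.0), (a,2.0), (b,4.0))

Note that x._2._1/x._2._2 in the original map is integer division, so the average for 'a' comes out as 2 rather than 2.0; the sketch uses toDouble to keep the fractional part.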
combineByKey takes three functions as arguments:
Function 1 = createCombiner: called once per key 'k' in each partition, the first time that key is seen in that partition
Function 2 = mergeValue: called once for every further occurrence of key 'k' within the same partition, i.e. (occurrences of 'k' in the partition - 1) times
Function 3 = mergeCombiners: called once for every additional partition that contains key 'k', i.e. (number of partitions containing 'k' - 1) times, to merge the per-partition results
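If you want to see these call counts for yourself, one option is to drop a println into each of the three functions (a debugging sketch only; the output shows up on the driver console in local mode, or in the executor logs on a cluster). With the five records above spread across 5 partitions, each partition should hold a single record, so createCombiner runs once per record, mergeValue never runs, and mergeCombiners runs once each for 'a' and 'b', the two keys that appear in more than one partition:

// Instrumented version of the same combineByKey call, printing each invocation.
val traced = kvDataDist.combineByKey(
  (v: Int) => { println(s"createCombiner($v)"); (v, 1) },
  (acc: (Int, Int), v: Int) => { println(s"mergeValue($acc, $v)"); (acc._1 + v, acc._2 + 1) },
  (a: (Int, Int), b: (Int, Int)) => { println(s"mergeCombiners($a, $b)"); (a._1 + b._1, a._2 + b._2) }
)
traced.collect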