Question
I've written a relatively simple Spark job in Scala which reads some data from S3, performs some transformations and aggregations and finally stores the results into a repository.
At the final stage, I have an RDD of my domain model objects, and I would like to group them into chunks so that I can do some mass insertions into my repository.
I used the RDDFunctions.sliding method to achieve that, and it's working almost fine. Here is a simplified version of my code:
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.rdd.RDDFunctions
import scala.concurrent.Await
import scala.concurrent.duration._

val processedElements: RDD[DomainModel] = ???

RDDFunctions.fromRDD(processedElements)
  .sliding(500, 500) // window of 500, step of 500 -> non-overlapping chunks
  .foreach { elementsChunk =>
    Await.ready(repository.bulkInsert(elementsChunk), 1.minute)
  }
The problem is that if, for example, I have 1020 elements, only 1000 of them end up in my repository. It looks like sliding ignores any additional elements when the window size is larger than the number of remaining elements.
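This is easy to reproduce in a spark-shell (a minimal sketch of what I'm seeing, assuming an available SparkContext sc and the implicit conversion from RDDFunctions):
import org.apache.spark.mllib.rdd.RDDFunctions._

// 1020 elements with a window and step of 500: only two full windows are produced,
// so the trailing 20 elements never appear in any chunk.
val chunks = sc.parallelize(1 to 1020).sliding(500, 500).collect()
chunks.map(_.length).foreach(println) // prints 500 and 500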
Is there any way to resolve this? If not, is there any other way to achieve the same behaviour without using RDDFunctions.sliding?
Answer 1:
Couldn't you just use foreachPartition and manual batch management?
import scala.collection.mutable.ArrayBuffer

val BATCH_SIZE = 500 // chunk size from the question

processedElements.foreachPartition { items =>
  val batch = new ArrayBuffer[DomainModel](BATCH_SIZE)
  while (items.hasNext) {
    // Flush the buffer once it is full, then keep filling it.
    if (batch.size >= BATCH_SIZE) {
      bulkInsert(batch)
      batch.clear()
    }
    batch += items.next()
  }
  // Insert the leftover elements -- the tail that sliding would drop.
  if (batch.nonEmpty) {
    bulkInsert(batch)
  }
}
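Not part of the original answer, but the same chunking can be written more compactly with Scala's Iterator.grouped, which also emits the final, smaller chunk (a sketch reusing the same processedElements, BATCH_SIZE and bulkInsert names):
processedElements.foreachPartition { items =>
  // grouped yields chunks of at most BATCH_SIZE elements, including the
  // last partial chunk, so no elements are dropped.
  items.grouped(BATCH_SIZE).foreach(chunk => bulkInsert(chunk))
}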
Answer 2:
You're right that Spark's sliding (unlike Scala's) produces an empty RDD if the window size exceeds the number of remaining items, according to the RDDFunctions doc. Nor does Spark have an equivalent of Scala's grouped.
If you know how many groups you want to create, a possible workaround is to split the RDD with modulo filters. Here's a simplified example of splitting an RDD into 5 groups:
val rdd = sc.parallelize(Seq(
  (0, "text0"), (1, "text1"), (2, "text2"), (3, "text3"), (4, "text4"), (5, "text5"),
  (6, "text6"), (7, "text7"), (8, "text8"), (9, "text9"), (10, "text10"), (11, "text11")
))

// g(n) keeps the elements whose key belongs to group n (key modulo 5 == n)
def g(n: Int)(x: Int): Boolean = x % 5 == n

val rddList = (0 to 4).map(n => rdd.filter(x => g(n)(x._1)))

(0 to 4).foreach(n => rddList(n).collect.foreach(println))
(0,text0)
(5,text5)
(10,text10)
(1,text1)
(6,text6)
(11,text11)
(2,text2)
(7,text7)
(3,text3)
(8,text8)
(4,text4)
(9,text9)
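The snippet above only shows the split; to get back to the question's bulk inserts, each sub-RDD would still have to be written separately. One possible way (my own sketch, not part of the original answer, reusing the question's repository and Await call, and assuming rddList holds the asker's DomainModel elements rather than the toy tuples above and that each group is small enough to collect on the driver):
import scala.concurrent.Await
import scala.concurrent.duration._

rddList.foreach { subRdd =>
  // Collect one group at a time and hand it to the repository as a single batch.
  val batch = subRdd.collect()
  if (batch.nonEmpty) {
    Await.ready(repository.bulkInsert(batch), 1.minute)
  }
}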
Source: https://stackoverflow.com/questions/43877678/spark-split-rdd-elements-into-chunks