pySpark forEachPartition - Where is code executed

Submitted by 只谈情不闲聊 on 2020-12-30 02:55:06

Question


I'm using PySpark 2.3 (I cannot update to 2.4 on my current dev system) and have the following questions concerning foreachPartition.

First a little context: As far as I understand, PySpark UDFs force the Python code to be executed outside the Java Virtual Machine (JVM) in a Python instance, which costs performance. Since I need to apply some Python functions to my data and want to minimize overhead, I had the idea to at least load a manageable chunk of data into the driver and process it as a Pandas DataFrame. However, this would lose the parallelism advantage Spark has. Then I read that foreachPartition applies a function to all the data within a partition and therefore allows parallel processing.

My questions now are:

  1. When I apply a Python function via foreachPartition, does the Python execution take place within the driver process (so that the partition data is transferred over the network to my driver)?

  2. Is the data processed row-wise within foreachPartition (meaning every RDD row is transferred one by one to the Python instance), or is the partition data processed at once (meaning, for example, the whole partition is transferred to the instance and handled as a whole by one Python instance)?

Thank you in advance for your input!


Edit: The "in driver" solution I used before foreachPartition, in order to process entire batches, looked like this:

def partition_generator(rdd):
    # glom() turns every partition into one list, so glom has exactly one element per partition
    glom = rdd.glom()
    #Optionally persist glom
    for partition in range(rdd.getNumPartitions()):
        # Keep only the list of partition number `partition` and pull it into the driver
        yield glom.mapPartitionsWithIndex(
            lambda idx, it, p=partition: it if idx == p else []
        ).collect()[0]

A little explanation of what happens here: glom groups the rows of each partition into one list. Taken from the docs:

glom(self): Return an RDD created by coalescing all elements within each partition into a list.

So the for-loop iterates over the number of available partitions (getNumPartitions()), and in every iteration exactly one partition is collected and yielded within the driver (glom.mapPartitionsWithIndex(...).collect()[0]).


Answer 1:


Luckily I stumbled upon this great explanation of mapPartitions from Mrinal (answered here).

mapPartitions applies a function to each partition of an RDD. Hence, work can run in parallel when the partitions are distributed over different nodes: the Python instances needed to run the Python function are created on those nodes (the executors), not in the driver. While foreachPartition only applies a function for its side effects (e.g. writing your data to a .csv file) and returns nothing, mapPartitions returns a new RDD. Therefore, using foreachPartition was the wrong choice for me.
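A minimal sketch of the kind of side-effect-only function foreachPartition expects, assuming you want one CSV file per partition. The output directory and the name `write_partition_csv` are hypothetical. The function only receives an iterator over the partition's rows (it is not told its partition number), so a random file name from `uuid` keeps partitions from overwriting each other. It returns nothing, which is why foreachPartition cannot hand you a new RDD.

```python
import csv
import os
import uuid

# Hypothetical output location; on a real cluster every executor must be able to write here
OUTPUT_DIR = "/tmp/partition_csv"


def write_partition_csv(rows, output_dir=OUTPUT_DIR):
    """Write all rows of one partition to their own CSV file.

    Intended for rdd.foreachPartition: called once per partition with an
    iterator over that partition's rows and used only for its side effect.
    """
    os.makedirs(output_dir, exist_ok=True)
    # No partition number is passed in, so use a random, collision-free file name
    path = os.path.join(output_dir, "part-%s.csv" % uuid.uuid4().hex)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        # Consume the iterator row by row; the partition is never held in memory as a whole
        for row in rows:
            writer.writerow(row)
```

With Spark this would be passed as `df.rdd.foreachPartition(write_partition_csv)` (PySpark `Row` objects are tuples, so `writerow` accepts them). Because it is a plain function of an iterator, it can also be tried locally, e.g. `write_partition_csv(iter([(1, "a"), (2, "b")]))`.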

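In order to answer my second question: with map or a row-wise Python UDF, the Python function is called once per row (the rows themselves are still sent to the Python worker in serialized batches, not one at a time), so any per-call overhead is paid for every row. foreachPartition and mapPartitions (both RDD methods) call the function once per partition and hand it an iterator over all rows of that partition, so setup work can be done once per partition.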
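The difference matters most when the function needs expensive setup. The following plain-Python sketch (no Spark needed) models two partitions as lists and counts how often a hypothetical `load_resource()` setup step runs. `load_resource`, `per_row` and `per_partition` are illustrative names, and the small stop-word set stands in for something costly such as a model or a database connection.

```python
setup_calls = 0


def load_resource():
    """Hypothetical expensive setup (e.g. loading a model); here just a stop-word set."""
    global setup_calls
    setup_calls += 1
    return {"the", "a", "an"}


def per_row(word):
    # map-style: called for every row, so the setup runs for every row
    stopwords = load_resource()
    return None if word.lower() in stopwords else word.lower()


def per_partition(words):
    # mapPartitions-style: called once per partition, so the setup runs once per partition
    stopwords = load_resource()
    for word in words:
        yield None if word.lower() in stopwords else word.lower()


# Two lists standing in for an RDD's partitions
partitions = [["The", "Cat"], ["A", "Dog", "Runs"]]

setup_calls = 0
row_results = [per_row(w) for p in partitions for w in p]
row_setups = setup_calls  # 5 rows -> 5 setups

setup_calls = 0
partition_results = [r for p in partitions for r in per_partition(iter(p))]
partition_setups = setup_calls  # 2 partitions -> 2 setups

print(row_setups, partition_setups)  # 5 2
```

In PySpark the two shapes correspond to `rdd.map(per_row)` and `rdd.mapPartitions(per_partition)`; both produce the same results, only the number of setup calls differs.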

Additionally, writing the partition function as a generator reduces the memory needed for this partition data: the partition arrives as an iterator, and each row is processed while iterating over it, so neither the input nor the results have to be held in memory as a whole.
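If, as in the question, the goal is to process chunks of rows together (for example as small Pandas DataFrames) without loading a whole partition into memory, the partition iterator can be cut into fixed-size batches inside the partition function. A minimal plain-Python sketch, assuming a hypothetical `batch_size` that fits comfortably in executor memory; `batched` and `process_partition` are illustrative names, and the per-batch work here only counts rows where real code might build a DataFrame.

```python
from itertools import islice


def batched(rows, batch_size):
    """Yield lists of at most batch_size rows from an iterator without materializing it."""
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def process_partition(rows, batch_size=1000):
    # Suitable for rdd.mapPartitions: only one batch is held in memory at a time
    for batch in batched(rows, batch_size):
        # Placeholder per-batch work; a real job might build pandas.DataFrame(batch) here
        yield len(batch)


print(list(batched(range(5), 2)))  # [[0, 1], [2, 3], [4]]
print(list(process_partition(iter(range(2500)))))  # [1000, 1000, 500]
```

The batch size is the knob between the two extremes discussed above: a batch size of 1 behaves like row-wise processing, while a batch size larger than the partition loads the whole partition at once.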

An example might look like:

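from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


def generator(partition):
    """
    Function yielding one result per row of a partition (in this case the lower-cased string from the "text" column)

    @partition: iterator over the Row objects of one partition
    """

    for row in partition:
        # Yield a one-element tuple so that toDF() sees a row with a single column
        yield (row["text"].lower(),)


df = spark.createDataFrame([("TESTA",), ("TESTB",)], ["text"])
df = df.repartition(2)  # two partitions, so generator() is called twice, possibly in parallel
df.rdd.mapPartitions(generator).toDF(["text"]).show()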


#Result:
+-----+
| text|
+-----+
|testa|
|testb|
+-----+

Hope this helps somebody facing similar problems :)




Answer 2:


PySpark UDFs execute next to the executors, i.e. in separate Python worker processes on each executor that run side by side with the JVM and pass data back and forth between the Spark engine (Scala) and the Python interpreter.

The same is true for code called inside foreachPartition: it runs in those Python workers on the executors, not in the driver.

Edit (after looking at the sample code):

  1. Using RDDs is not an efficient way of using Spark; you should move to DataFrames (PySpark has no typed Dataset API).
  2. What pulls all the data to the driver in your code is the collect() call.
  3. foreachPartition works on a whole partition at a time, similar to glom, but runs on the executors.


Source: https://stackoverflow.com/questions/55654982/pyspark-foreachpartition-where-is-code-executed
