As the title describes, I have a PySpark DataFrame in which I need to cast two columns into a new column that is a list of tuples, based on the value of a third column. This cast will reduce or flatten the dataframe by a key value, product id in this case, so that the result is one row per key.
There are hundreds of millions of rows in this dataframe, with 37M unique product ids, so I need a way to do the transformation on the Spark cluster without bringing any data back to the driver (Jupyter in this case).
Here is an extract of my dataframe for just 1 product:
+-----------+-------------------+-------------+--------+----------+---------------+
| product_id|      purchase_date|days_warranty|store_id|year_month|       category|
+-----------+-------------------+-------------+--------+----------+---------------+
|02147465400|2017-05-16 00:00:00|           30|     205|   2017-05|     CATEGORY A|
|02147465400|2017-04-15 00:00:00|           30|     205|   2017-04|     CATEGORY A|
|02147465400|2018-07-11 00:00:00|           30|     205|   2018-07|     CATEGORY A|
|02147465400|2017-06-14 00:00:00|           30|     205|   2017-06|     CATEGORY A|
|02147465400|2017-03-16 00:00:00|           30|     205|   2017-03|     CATEGORY A|
|02147465400|2017-08-14 00:00:00|           30|     205|   2017-08|     CATEGORY A|
|02147465400|2017-09-12 00:00:00|           30|     205|   2017-09|     CATEGORY A|
|02147465400|2017-01-21 00:00:00|           30|     205|   2017-01|     CATEGORY A|
|02147465400|2018-08-14 00:00:00|           30|     205|   2018-08|     CATEGORY A|
|02147465400|2018-08-23 00:00:00|           30|     205|   2018-08|     CATEGORY A|
|02147465400|2017-10-11 00:00:00|           30|     205|   2017-10|     CATEGORY A|
|02147465400|2017-12-12 00:00:00|           30|     205|   2017-12|     CATEGORY A|
|02147465400|2017-02-15 00:00:00|           30|     205|   2017-02|     CATEGORY A|
|02147465400|2018-04-12 00:00:00|           30|     205|   2018-04|     CATEGORY A|
|02147465400|2018-03-12 00:00:00|           30|     205|   2018-03|     CATEGORY A|
|02147465400|2018-05-15 00:00:00|           30|     205|   2018-05|     CATEGORY A|
|02147465400|2018-02-12 00:00:00|           30|     205|   2018-02|     CATEGORY A|
|02147465400|2018-06-14 00:00:00|           30|     205|   2018-06|     CATEGORY A|
|02147465400|2018-01-11 00:00:00|           30|     205|   2018-01|     CATEGORY A|
|02147465400|2017-07-20 00:00:00|           30|     205|   2017-07|     CATEGORY A|
|02147465400|2017-11-11 00:00:00|           30|     205|   2017-11|     CATEGORY A|
|02147465400|2017-01-05 00:00:00|           90|     205|   2017-01|     CATEGORY B|
|02147465400|2017-01-21 00:00:00|           90|     205|   2017-01|     CATEGORY B|
|02147465400|2017-10-09 00:00:00|           90|     205|   2017-10|     CATEGORY B|
|02147465400|2018-07-11 00:00:00|           90|     205|   2018-07|     CATEGORY B|
|02147465400|2017-04-16 00:00:00|           90|     205|   2017-04|     CATEGORY B|
|02147465400|2018-09-16 00:00:00|           90|     205|   2018-09|     CATEGORY B|
|02147465400|2018-04-14 00:00:00|           90|     205|   2018-04|     CATEGORY B|
|02147465400|2018-01-12 00:00:00|           90|     205|   2018-01|     CATEGORY B|
|02147465400|2017-07-15 00:00:00|           90|     205|   2017-07|     CATEGORY B|
+-----------+-------------------+-------------+--------+----------+---------------+
Here is the desired resulting dataframe, with one row for this product, where the purchase_date and days_warranty columns of the original rows have been collected into arrays of tuples in new columns, one per category column value:
+-----------+----------------------------+----------------------------+
| product_id|                  CATEGORY A|                  CATEGORY B|
+-----------+----------------------------+----------------------------+
|02147465400| [ (2017-05-16 00:00:00,30),| [ (2017-01-05 00:00:00,90),|
|           |   (2017-04-15 00:00:00,30),|   (2017-01-21 00:00:00,90),|
|           |   (2018-07-11 00:00:00,30),|   (2017-10-09 00:00:00,90),|
|           |   (2017-06-14 00:00:00,30),|   (2018-07-11 00:00:00,90),|
|           |   (2017-03-16 00:00:00,30),|   (2017-04-16 00:00:00,90),|
|           |   (2017-08-14 00:00:00,30),|   (2018-09-16 00:00:00,90),|
|           |   (2017-09-12 00:00:00,30),|   (2018-04-14 00:00:00,90),|
|           |   (2017-01-21 00:00:00,30),|   (2018-01-12 00:00:00,90),|
|           |   (2018-08-14 00:00:00,30),|   (2017-07-15 00:00:00,90) |
|           |   (2018-08-23 00:00:00,30),| ]                          |
|           |   (2017-10-11 00:00:00,30),|                            |
|           |   (2017-12-12 00:00:00,30),|                            |
|           |   (2017-02-15 00:00:00,30),|                            |
|           |   (2018-04-12 00:00:00,30),|                            |
|           |   (2018-03-12 00:00:00,30),|                            |
|           |   (2018-05-15 00:00:00,30),|                            |
|           |   (2018-02-12 00:00:00,30),|                            |
|           |   (2018-06-14 00:00:00,30),|                            |
|           |   (2018-01-11 00:00:00,30),|                            |
|           |   (2017-07-20 00:00:00,30) |                            |
|           | ]                          |                            |
+-----------+----------------------------+----------------------------+
Assuming your DataFrame is called df:
from pyspark.sql.functions import struct, collect_list

# Pack the two columns into a struct, then pivot on category,
# collecting one list of structs per product_id.
gdf = (df.select("product_id", "category", struct("purchase_date", "days_warranty").alias("pd_wd"))
         .groupBy("product_id")
         .pivot("category")
         .agg(collect_list("pd_wd")))
Essentially, you have to pack purchase_date and days_warranty into a single column using struct(). Then you just group by product_id, pivot by category, and aggregate with collect_list().
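If the pivot step itself is slow, one optional tweak (a sketch, assuming the full set of category values is known ahead of time) is to pass the values to pivot() explicitly, which lets Spark skip the extra job that computes the distinct categories:

from pyspark.sql.functions import struct, collect_list

# assumption: these are all the category values present in the data
known_categories = ["CATEGORY A", "CATEGORY B"]

gdf = (df.select("product_id", "category", struct("purchase_date", "days_warranty").alias("pd_wd"))
         .groupBy("product_id")
         .pivot("category", known_categories)  # no extra pass over the data to discover distinct values
         .agg(collect_list("pd_wd")))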
In case you run into performance issues with pivot, the approach below is another solution to the same problem, although it gives you more control by splitting the job into one phase per category with a for loop. On every iteration it joins the new data for the current category into acc_df, which holds the accumulated results.
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType

# the UDF returns an array of (purchase_date, days_warranty) structs
schema = ArrayType(
    StructType((
        StructField("p_date", StringType(), False),
        StructField("d_warranty", StringType(), False)
    ))
)

tuple_list_udf = udf(tuple_list, schema)  # tuple_list is defined at the end of this answer
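Note that the schema above declares both struct fields as strings. If purchase_date is a timestamp in your DataFrame, you may want to cast it to a string first so the collected values match the declared return type (a minimal sketch, assuming the default timestamp formatting is acceptable):

# assumption: the UDF's StringType fields should receive string inputs
df = df.withColumn("purchase_date", F.col("purchase_date").cast("string"))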
buf_size = 5 # if you get OOM error decrease this to persist more often
categories = df.select("category").distinct().collect()
acc_df = spark.createDataFrame(sc.emptyRDD(), df.schema) # create an empty df which holds the accumulated results for each category
for idx, c in enumerate(categories):
    col_name = c[0].replace(" ", "_")  # Spark complains about column names containing spaces

    # one pass per category: collect the dates and warranties, then zip them into tuples
    cat_df = df.where(df["category"] == c[0]) \
        .groupBy("product_id") \
        .agg(
            F.collect_list(F.col("purchase_date")).alias("p_date"),
            F.collect_list(F.col("days_warranty")).alias("d_warranty")) \
        .withColumn(col_name, tuple_list_udf(F.col("p_date"), F.col("d_warranty"))) \
        .drop("p_date", "d_warranty")

    if idx == 0:
        acc_df = cat_df
    else:
        acc_df = acc_df \
            .join(cat_df.alias("cat_df"), "product_id") \
            .drop(F.col("cat_df.product_id"))

    # you can persist here every buf_size iterations
    if (idx + 1) % buf_size == 0:
        acc_df = acc_df.persist()
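Because acc_df is built up as a chain of joins, its query plan keeps growing with each iteration, and persist() caches the data without shortening that plan. If planning time or memory becomes a problem, one alternative (a sketch, assuming a checkpoint directory is available to the application) is to checkpoint instead, which materializes acc_df and truncates its lineage:

# assumption: a path on the cluster's filesystem that Spark can write to
sc.setCheckpointDir("/tmp/spark-checkpoints")

# then, inside the loop, every buf_size iterations:
if (idx + 1) % buf_size == 0:
    acc_df = acc_df.checkpoint()  # writes acc_df out and cuts its lineage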
The function tuple_list is responsible for generating a list of tuples from the collected purchase_date and days_warranty lists:
def tuple_list(pdl, dwl):
return list(zip(pdl, dwl))
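On plain Python values it behaves like this (a quick illustration with made-up inputs):

# hypothetical inputs: the two collected lists for one product
dates = ["2017-05-16 00:00:00", "2017-04-15 00:00:00"]
warranties = [30, 30]

tuple_list(dates, warranties)
# -> [('2017-05-16 00:00:00', 30), ('2017-04-15 00:00:00', 30)]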
The output of this will be:
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|product_id |CATEGORY_B |CATEGORY_A |
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|02147465400|[[2017-04-16 00:00:00, 90], [2018-09-16 00:00:00, 90], [2017-10-09 00:00:00, 90], [2018-01-12 00:00:00, 90], [2018-07-11 00:00:00, 90], [2017-01-21 00:00:00, 90], [2018-04-14 00:00:00, 90], [2017-01-05 00:00:00, 90], [2017-07-15 00:00:00, 90]]|[[2017-06-14 00:00:00, 30], [2018-08-14 00:00:00, 30], [2018-01-11 00:00:00, 30], [2018-04-12 00:00:00, 30], [2017-10-11 00:00:00, 30], [2017-05-16 00:00:00, 30], [2018-05-15 00:00:00, 30], [2017-04-15 00:00:00, 30], [2017-02-15 00:00:00, 30], [2018-02-12 00:00:00, 30], [2017-01-21 00:00:00, 30], [2018-07-11 00:00:00, 30], [2018-06-14 00:00:00, 30], [2017-03-16 00:00:00, 30], [2017-07-20 00:00:00, 30], [2018-08-23 00:00:00, 30], [2017-09-12 00:00:00, 30], [2018-03-12 00:00:00, 30], [2017-12-12 00:00:00, 30], [2017-08-14 00:00:00, 30], [2017-11-11 00:00:00, 30]]|
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Source: https://stackoverflow.com/questions/55167913/pyspark-dataframe-cast-two-columns-into-new-column-of-tuples-based-value-of-a-th