PySpark DataFrame: cast two columns into a new column of tuples based on the value of a third column

Submitted by 纵饮孤独 on 2019-12-04 18:49:13

Assuming your Dataframe is called df:

from pyspark.sql.functions import struct
from pyspark.sql.functions import collect_list

gdf = (df.select("product_id", "category", struct("purchase_date", "warranty_days").alias("pd_wd"))
       .groupBy("product_id")
       .pivot("category")
       .agg(collect_list("pd_wd")))

Essentially, you combine purchase_date and warranty_days into a single column using struct(). Then you group by product_id, pivot on category, and aggregate with collect_list().
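
To make the shape of the result concrete, here is a plain-Python sketch of the same reshaping, with no Spark involved. The sample rows and the helper name pivot_collect are made up for illustration. It mimics what struct(), groupBy(), pivot() and collect_list() produce together, not how Spark executes them.

```python
from collections import defaultdict

# Hypothetical sample rows: (product_id, category, purchase_date, warranty_days)
rows = [
    ("p1", "CATEGORY_A", "2017-06-14", 30),
    ("p1", "CATEGORY_A", "2018-08-14", 30),
    ("p1", "CATEGORY_B", "2017-04-16", 90),
    ("p2", "CATEGORY_B", "2018-01-12", 90),
]


def pivot_collect(rows):
    """Group by product_id, pivot on category, and collect (date, days) pairs."""
    categories = sorted({category for _, category, _, _ in rows})
    grouped = defaultdict(lambda: {c: [] for c in categories})
    for product_id, category, purchase_date, warranty_days in rows:
        # The (purchase_date, warranty_days) pair plays the role of the struct column
        grouped[product_id][category].append((purchase_date, warranty_days))
    return dict(grouped)


result = pivot_collect(rows)
for product_id, by_category in sorted(result.items()):
    print(product_id, by_category)
```

Each product ends up with one entry per category, and each entry is a list of (purchase_date, warranty_days) pairs. In this sketch p2 has no CATEGORY_A rows, so that cell is an empty list.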

If pivot causes performance problems, the approach below solves the same problem while giving you more control: it splits the job into one phase per category using a for loop. Each iteration computes the column for one category and joins it onto acc_df, which holds the accumulated results.

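from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

# Return type of the UDF: an array of (p_date, d_warranty) structs
schema = ArrayType(
    StructType([
        StructField("p_date", StringType(), False),
        StructField("d_warranty", StringType(), False)
    ])
)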

# tuple_list is the plain Python function shown further down; define it before this line
tuple_list_udf = udf(tuple_list, schema)

buf_size = 5 # if you get an OOM error, decrease this to persist more often

categories = df.select("category").distinct().collect()

# placeholder for the accumulated results; it is replaced on the first iteration
acc_df = spark.createDataFrame(spark.sparkContext.emptyRDD(), df.schema)

for idx, c in enumerate(categories):
    col_name = c[0].replace(" ", "_") # Spark complains about column names containing spaces
    cat_df = df.where(df["category"] == c[0]) \
            .groupBy("product_id") \
            .agg(
                F.collect_list(F.col("purchase_date")).alias("p_date"),
                F.collect_list(F.col("days_warranty")).alias("d_warranty")) \
            .withColumn(col_name, tuple_list_udf(F.col("p_date"), F.col("d_warranty"))) \
            .drop("p_date", "d_warranty")

    if idx == 0:
        acc_df = cat_df
    else:
        # default inner join: a product missing from any category is dropped
        acc_df = acc_df \
            .join(cat_df.alias("cat_df"), "product_id") \
            .drop(F.col("cat_df.product_id"))

    # persist every buf_size iterations; the parentheses are needed because % binds tighter than +
    if (idx + 1) % buf_size == 0:
        acc_df = acc_df.persist()
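
The persist check is meant to fire on every buf_size-th iteration. Operator precedence is easy to get wrong here: in Python, % binds tighter than +, so the check needs parentheses around idx + 1. The small stand-alone sketch below (the range of 12 iterations is an arbitrary choice for illustration) shows which zero-based idx values trigger each form:

```python
buf_size = 5

# Intended check: true on every buf_size-th iteration (idx is zero-based)
with_parens = [idx for idx in range(12) if (idx + 1) % buf_size == 0]

# Without parentheses this parses as idx + (1 % buf_size) == 0,
# which is never true for a non-negative idx
without_parens = [idx for idx in range(12) if idx + 1 % buf_size == 0]

print(with_parens)     # [4, 9]
print(without_parens)  # []
```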

The function tuple_list builds a list of tuples by pairing the collected purchase_date and days_warranty lists element by element.

def tuple_list(pdl, dwl):
    return list(zip(pdl, dwl))
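
As a quick check outside Spark, tuple_list can be called on ordinary Python lists. The sample values below are made up for illustration. The pairing is purely positional, so it assumes the two lists arrive in the same order, and zip silently stops at the shorter list if the lengths differ.

```python
def tuple_list(pdl, dwl):
    # Pair the i-th purchase date with the i-th warranty value
    return list(zip(pdl, dwl))


# Hypothetical inputs standing in for the two collect_list results
dates = ["2017-04-16 00:00:00", "2018-09-16 00:00:00"]
days = ["90", "90"]

pairs = tuple_list(dates, days)
print(pairs)

# With unequal lengths, the extra date is dropped without any error
short = tuple_list(dates, ["90"])
print(short)
```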

The output of this will be:

+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|product_id |CATEGORY_B                                                                                                                                                                                                                                         |CATEGORY_A                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             |
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|02147465400|[[2017-04-16 00:00:00, 90], [2018-09-16 00:00:00, 90], [2017-10-09 00:00:00, 90], [2018-01-12 00:00:00, 90], [2018-07-11 00:00:00, 90], [2017-01-21 00:00:00, 90], [2018-04-14 00:00:00, 90], [2017-01-05 00:00:00, 90], [2017-07-15 00:00:00, 90]]|[[2017-06-14 00:00:00, 30], [2018-08-14 00:00:00, 30], [2018-01-11 00:00:00, 30], [2018-04-12 00:00:00, 30], [2017-10-11 00:00:00, 30], [2017-05-16 00:00:00, 30], [2018-05-15 00:00:00, 30], [2017-04-15 00:00:00, 30], [2017-02-15 00:00:00, 30], [2018-02-12 00:00:00, 30], [2017-01-21 00:00:00, 30], [2018-07-11 00:00:00, 30], [2018-06-14 00:00:00, 30], [2017-03-16 00:00:00, 30], [2017-07-20 00:00:00, 30], [2018-08-23 00:00:00, 30], [2017-09-12 00:00:00, 30], [2018-03-12 00:00:00, 30], [2017-12-12 00:00:00, 30], [2017-08-14 00:00:00, 30], [2017-11-11 00:00:00, 30]]|
+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+