Question
I have a Spark dataframe with an ID column and, among other columns, an array column whose value is the list of IDs of that record's related records.
An example dataframe:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam | [789,999]
789 | marc | [111]
555 | dan | [333]
From the above, I need to append all the related child IDs to the array column of the parent ID. The resulting DF should look like:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam | [789,999,111]
789 | marc | [111]
555 | dan | [333]
I need help on how to do it. Thanks.
Answer 1:
One way to handle this task is to do an iterative self left-join that updates RELATED_IDLIST on each pass, repeating until some condition is satisfied (this works only when the max depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-delimited StringType column, use the SQL builtin function find_in_set, and add a new column PROCESSED_IDLIST to set up the join conditions; see below for the main logic:
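First, a quick sanity check of find_in_set (a throwaway snippet of mine, not part of the answer's pipeline): it returns the 1-based position of a value inside a comma-delimited string, or 0 when the value is absent:

spark.sql("SELECT find_in_set('456','345,456') AS found, find_in_set('789','345,456') AS missing").show()
+-----+-------+
|found|missing|
+-----+-------+
|    2|      0|
+-----+-------+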
Functions:
from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left-join and then returns another
# dataframe with exactly the same schema as the input dataframe; repeat until some condition is satisfied
def recursive_join(d, max_iter=10):
    # function to find direct child-IDs and merge them into RELATED_IDLIST
    def find_child_idlist(_df):
        return _df.alias('d1').join(
            _df.alias('d2'),
            F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
            "left"
        ).groupby("d1.ID", "d1.NAME").agg(
            F.expr("""
              /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
               * and remove the trailing comma left over when all d2.RELATED_IDLIST are NULL */
              trim(TRAILING ',' FROM
                concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
              ) as RELATED_IDLIST"""),
            F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
        )
    # below is the main logic: join once, then recurse while any row still picked up new IDs
    d = find_child_idlist(d).persist()
    if (d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0) & (max_iter > 1):
        d = recursive_join(d, max_iter-1)
    return d

# define a pandas_udf to remove duplicates from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([ list(set(x)) for x in s ]), "array<int>")
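To try get_uniq in isolation, here is a throwaway check of mine (pandas_udf requires PyArrow to be installed; set() gives no ordering guarantee, so element order may differ):

spark.createDataFrame([([1, 2, 2, 3],)], ["a"]) \
    .select(get_uniq("a").alias("a")).show()  # e.g. [1, 2, 3]; order may vary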
Where:
- in the function find_child_idlist(), the left-join must satisfy the following two conditions:
  - d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID,d1.RELATED_IDLIST)>0
  - d2.ID is not in d1.PROCESSED_IDLIST: find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1
- quit recursive_join when no row satisfies RELATED_IDLIST!=PROCESSED_IDLIST, or when max_iter drops to 1
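For intuition, here is a hand trace of mike's row (ID 123) through the iterations, derived from the sample data above (789 shows up twice because both 345 and 456 list it):

start:  RELATED_IDLIST = 345,456                    PROCESSED_IDLIST = 123
iter 1: 345, 456 match -> RELATED_IDLIST = 345,456,789,789,999      PROCESSED_IDLIST = 345,456
iter 2: 789 matches    -> RELATED_IDLIST = 345,456,789,789,999,111  PROCESSED_IDLIST = 345,456,789,789,999
iter 3: no new match, so RELATED_IDLIST == PROCESSED_IDLIST; once this holds for every row, the recursion stops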
Processing:
set up the dataframe:
df = spark.createDataFrame([
    (123, "mike", [345,456]),
    (345, "alen", [789]),
    (456, "sam", [789,999]),
    (789, "marc", [111]),
    (555, "dan", [333])
], ["ID", "NAME", "RELATED_IDLIST"])
add a new column PROCESSED_IDLIST to save the RELATED_IDLIST from the previous join, and then run recursive_join():
df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
        .withColumn('PROCESSED_IDLIST', F.col('ID'))

df_new = recursive_join(df1, 5)
df_new.show(10,0)
+---+----+-----------------------+-----------------------+
|ID |NAME|RELATED_IDLIST         |PROCESSED_IDLIST       |
+---+----+-----------------------+-----------------------+
|555|dan |333                    |333                    |
|789|marc|111                    |111                    |
|345|alen|789,111                |789,111                |
|123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
|456|sam |789,999,111            |789,999,111            |
+---+----+-----------------------+-----------------------+
split RELATED_IDLIST into an array of integers and then use the pandas_udf function to drop duplicate array elements:

df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
+---+----+-------------------------+-----------------------+
|ID |NAME|RELATED_IDLIST           |PROCESSED_IDLIST       |
+---+----+-------------------------+-----------------------+
|555|dan |[333]                    |333                    |
|789|marc|[111]                    |111                    |
|345|alen|[789, 111]               |789,111                |
|123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
|456|sam |[111, 789, 999]          |789,999,111            |
+---+----+-------------------------+-----------------------+
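As a side note beyond the original answer: if you are on Spark 2.4 or later, the builtin array_distinct does the same deduplication without a pandas_udf:

df_new.withColumn(
    "RELATED_IDLIST",
    F.array_distinct(F.split("RELATED_IDLIST", ",").cast("array<int>"))  # builtin since Spark 2.4
).show(10, 0)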
Source: https://stackoverflow.com/questions/64921362/pyspark-get-related-records-from-its-array-object-values