PySpark get related records from its array object values

巧了我就是萌 submitted on 2020-12-13 03:12:44

Question


I have a Spark dataframe with an ID column and, among other columns, an array column whose values are the IDs of the record's related records.

An example dataframe:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam  | [789,999]
789 | marc | [111]
555 | dan  | [333]

From the above, I need to append all the related child IDs to the array column of each parent ID. The resulting dataframe should look like this:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam  | [789,999,111]
789 | marc | [111]
555 | dan  | [333]
 

I need help on how to do this. Thanks.


Answer 1:


One way to handle this task is to do a self left join that updates RELATED_IDLIST, and to repeat it until some condition is satisfied (this works only when the maximum depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-delimited StringType column, then use the SQL built-in function find_in_set and a new column PROCESSED_IDLIST to set up the join conditions. See below for the main logic:

Functions:

from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left join and then returns another
# dataframe with exactly the same schema as the input dataframe; repeat until some condition is satisfied
def recursive_join(d, max_iter=10):
  # function to find direct child-IDs and merge into RELATED_IDLIST
  def find_child_idlist(_df):
    return _df.alias('d1').join(
        _df.alias('d2'), 
        F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
        "left"
      ).groupby("d1.ID", "d1.NAME").agg(
        F.expr("""
          /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
           * and remove the trailing comma left when all d2.RELATED_IDLIST are NULL */
          trim(TRAILING ',' FROM
              concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
          ) as RELATED_IDLIST"""),
        F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
    )
  # main logic: join once, then recurse while any row changed and iterations remain
  d = find_child_idlist(d).persist()
  if d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0 and max_iter > 1:
    d = recursive_join(d, max_iter-1)
  return d

# define a pandas_udf to remove duplicates from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([ list(set(x)) for x in s ]), "array<int>")
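
To make the fixed-point idea behind recursive_join() easier to follow, here is a minimal plain-Python sketch of the same kind of iteration, run on the example data without Spark. The names expand_related, related, current and processed are hypothetical and do not appear in the answer. Unlike the Spark version, this sketch skips duplicates while appending instead of removing them at the end.

```python
# Plain-Python sketch (no Spark): ID -> direct RELATED_IDLIST from the example.
related = {
    123: [345, 456],
    345: [789],
    456: [789, 999],
    789: [111],
    555: [333],
}

def expand_related(related, max_iter=10):
    """Append the children of not-yet-processed IDs until nothing changes."""
    # current[k] is the accumulated list for row k;
    # processed[k] holds the IDs of row k whose children were already appended.
    current = {k: list(v) for k, v in related.items()}
    processed = {k: set() for k in related}
    for _ in range(max_iter):
        changed = False
        for key, ids in current.items():
            for child in list(ids):  # snapshot: new IDs wait for the next pass
                if child in processed[key]:
                    continue
                processed[key].add(child)
                for grandchild in related.get(child, []):
                    if grandchild not in ids:
                        ids.append(grandchild)
                        changed = True
        if not changed:
            break
    return current

result = expand_related(related)
print(result[123])  # [345, 456, 789, 999, 111]
```

As in the Spark version, a max_iter smaller than the depth of the hierarchy would return an incomplete list.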

Where:

  1. in the function find_child_idlist(), the left-join must satisfy the following two conditions:

    • d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID,d1.RELATED_IDLIST)>0
    • d2.ID not in d1.PROCESSED_IDLIST: find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1
  2. recursive_join stops when no row satisfies RELATED_IDLIST!=PROCESSED_IDLIST or when max_iter has dropped to 1
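
The two join conditions above can be tried out with a small plain-Python emulation of find_in_set. This is only a sketch of the documented behaviour (1-based position in a comma-delimited string, 0 when the value is absent or itself contains a comma). NULL handling is left out, and join_matches is a hypothetical helper name.

```python
def find_in_set(needle, csv):
    """Emulate Spark SQL find_in_set for non-NULL inputs."""
    needle = str(needle)
    if "," in needle:
        return 0
    items = csv.split(",")
    return items.index(needle) + 1 if needle in items else 0

def join_matches(d2_id, related_idlist, processed_idlist):
    """True when d2_id is related to the row and was not processed before."""
    return (find_in_set(d2_id, related_idlist) > 0
            and find_in_set(d2_id, processed_idlist) < 1)

# First pass for ID 123: PROCESSED_IDLIST starts out as the row's own ID.
print(join_matches(345, "345,456", "123"))                  # True
# Second pass: 345 was already expanded, 789 is new.
print(join_matches(345, "345,456,789,789,999", "345,456"))  # False
print(join_matches(789, "345,456,789,789,999", "345,456"))  # True
```

Because the match is on whole comma-separated tokens, an ID such as 78 does not match inside "789,999", which a plain substring search would get wrong.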

Processing:

  1. set up dataframe:

    df = spark.createDataFrame([
      (123, "mike", [345,456]), (345, "alen", [789]), (456, "sam", [789,999]),
      (789, "marc", [111]), (555, "dan", [333])
    ],["ID", "NAME", "RELATED_IDLIST"])
    
  2. add a new column PROCESSED_IDLIST to hold the RELATED_IDLIST from the previous join (initialized to ID), then run recursive_join()

    df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
        .withColumn('PROCESSED_IDLIST', F.col('ID'))
    
    df_new = recursive_join(df1, 5)
    df_new.show(10,0)
    +---+----+-----------------------+-----------------------+
    |ID |NAME|RELATED_IDLIST         |PROCESSED_IDLIST       |
    +---+----+-----------------------+-----------------------+
    |555|dan |333                    |333                    |
    |789|marc|111                    |111                    |
    |345|alen|789,111                |789,111                |
    |123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
    |456|sam |789,999,111            |789,999,111            |
    +---+----+-----------------------+-----------------------+
    
  3. split RELATED_IDLIST into an array of integers and then use the pandas_udf function to drop duplicate array elements:

    df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
    +---+----+-------------------------+-----------------------+
    |ID |NAME|RELATED_IDLIST           |PROCESSED_IDLIST       |
    +---+----+-------------------------+-----------------------+
    |555|dan |[333]                    |333                    |
    |789|marc|[111]                    |111                    |
    |345|alen|[789, 111]               |789,111                |
    |123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
    |456|sam |[111, 789, 999]          |789,999,111            |
    +---+----+-------------------------+-----------------------+
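
The set()-based get_uniq does not keep the original element order, as the output above shows. If order matters, an order-preserving de-duplication could be used instead. The sketch below shows it in plain Python on the string produced for ID 123; unique_in_order is a hypothetical helper name, and wiring it into the pandas_udf is left as an assumption that has not been tested here. On Spark 2.4 or later, the built-in array_distinct function may make the UDF unnecessary.

```python
def unique_in_order(ids):
    """Drop duplicates while keeping the first occurrence of each ID."""
    # dict keys are unique and preserve insertion order (Python 3.7+)
    return list(dict.fromkeys(ids))

# RELATED_IDLIST string for ID 123 after recursive_join
ids = [int(x) for x in "345,456,789,789,999,111".split(",")]
deduped = unique_in_order(ids)
print(deduped)  # [345, 456, 789, 999, 111]
```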
    


Source: https://stackoverflow.com/questions/64921362/pyspark-get-related-records-from-its-array-object-values
