How to do opposite of explode in PySpark?


Let's say I have a DataFrame with a column for users and another column for words they've written:
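
A minimal sketch of the kind of DataFrame in question, reconstructed from the answers below:

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Input: one row per (user, word) pair -- the "exploded" shape.
    df = spark.createDataFrame([Row(user='Bob', word='hello'),
                                Row(user='Bob', word='world'),
                                Row(user='Mary', word='Have'),
                                Row(user='Mary', word='a'),
                                Row(user='Mary', word='nice'),
                                Row(user='Mary', word='day')])

    # Desired: one row per user, with all of that user's words collected
    # into a single array column -- i.e. the opposite of explode().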



        
5 Answers
  • 2020-12-16 16:24

    As of the Spark 2.3 release we now have Pandas UDFs (a.k.a. vectorized UDFs). The function below will accomplish the OP's task. A benefit of using this approach is that the order is guaranteed to be preserved; order is essential in many cases, such as time series analysis.

    import pandas as pd
    import findspark
    
    findspark.init()
    import pyspark
    from pyspark.sql import SparkSession, Row
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.types import StructType, StructField, ArrayType
    
    spark = SparkSession.builder.appName('test_collect_array_grouped').getOrCreate()
    
    def collect_array_grouped(df, groupbyCols, aggregateCol, outputCol):
        """
        Aggregate function: returns a new :class:`DataFrame` such that for a given column, aggregateCol,
        in a DataFrame, df, collect into an array the elements for each grouping defined by the groupbyCols list.
        The new DataFrame will have, for each row, the grouping columns and an array of the grouped
        values from aggregateCol in the outputCol.
    
        :param groupbyCols: list of columns to group by.
                Each element should be a column name (string) or an expression (:class:`Column`).
        :param aggregateCol: the column name of the column of values to aggregate into an array
                for each grouping.
        :param outputCol: the column name of the column to output the aggregated array to.
        """
        groupbyCols = [] if groupbyCols is None else groupbyCols
        df = df.select(groupbyCols + [aggregateCol])
        schema = df.select(groupbyCols).schema
        aggSchema = df.select(aggregateCol).schema
        arrayField = StructField(name=outputCol, dataType=ArrayType(aggSchema[0].dataType, False))
        schema = schema.add(arrayField)
        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def _get_array(pd_df):
            vals = pd_df[groupbyCols].iloc[0].tolist()
            vals.append(pd_df[aggregateCol].values)
            return pd.DataFrame([vals])
        return df.groupby(groupbyCols).apply(_get_array)
    
    rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                          Row(user='Bob', word='world'),
                                          Row(user='Mary', word='Have'),
                                          Row(user='Mary', word='a'),
                                          Row(user='Mary', word='nice'),
                                          Row(user='Mary', word='day')])
    df = spark.createDataFrame(rdd)
    
    collect_array_grouped(df, ['user'], 'word', 'users_words').show()
    
    +----+--------------------+
    |user|         users_words|
    +----+--------------------+
    |Mary|[Have, a, nice, day]|
    | Bob|      [hello, world]|
    +----+--------------------+
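
    A side note, not from the original answer: on Spark 3.0+ this decorated GROUPED_MAP style of pandas_udf is deprecated in favour of GroupedData.applyInPandas. A minimal sketch of the same grouped collect, assuming the df built above:

    import pandas as pd

    def to_array(pdf: pd.DataFrame) -> pd.DataFrame:
        # applyInPandas hands this function one pandas DataFrame per group;
        # collapse it into a single row whose second column is the list of words.
        return pd.DataFrame([[pdf['user'].iloc[0], pdf['word'].tolist()]],
                            columns=['user', 'users_words'])

    df.groupby('user').applyInPandas(to_array, schema='user string, users_words array<string>').show()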
    
  • 2020-12-16 16:25

    Thanks to @titipat for the RDD solution. I realized shortly after posting that there is also a DataFrame solution using collect_set (or collect_list):

    from pyspark.sql import Row
    from pyspark.sql.functions import collect_set
    rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                          Row(user='Bob', word='world'),
                                          Row(user='Mary', word='Have'),
                                          Row(user='Mary', word='a'),
                                          Row(user='Mary', word='nice'),
                                          Row(user='Mary', word='day')])
    df = spark.createDataFrame(rdd)
    group_user = df.groupBy('user').agg(collect_set('word').alias('words'))
    print(group_user.collect())
    
    >[Row(user='Mary', words=['Have', 'nice', 'day', 'a']), Row(user='Bob', words=['world', 'hello'])]
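
    Note, not from the original answer: collect_set drops duplicates and gives no ordering guarantee, which is why the words above come back in a different order; collect_list, mentioned above, is the drop-in alternative when duplicates must be kept (its ordering is likewise not guaranteed after a shuffle). A minimal sketch against the same df:

    from pyspark.sql.functions import collect_list

    group_user_list = df.groupBy('user').agg(collect_list('word').alias('words'))
    print(group_user_list.collect())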
    
  • 2020-12-16 16:27

    Here is a solution using the RDD API.

    from pyspark.sql import Row
    rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                          Row(user='Bob', word='world'),
                                          Row(user='Mary', word='Have'),
                                          Row(user='Mary', word='a'),
                                          Row(user='Mary', word='nice'),
                                          Row(user='Mary', word='day')])
    group_user = rdd.groupBy(lambda x: x.user)
    group_agg = group_user.map(lambda x: Row(**{'user': x[0], 'word': [t.word for t in x[1]]}))
    

    Output from group_agg.collect():

    [Row(user='Bob', word=['hello', 'world']),
    Row(user='Mary', word=['Have', 'a', 'nice', 'day'])]
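
    If a DataFrame is needed afterwards, the aggregated RDD of Rows can be turned back into one (a sketch, not part of the original answer):

    group_agg_df = spark.createDataFrame(group_agg)
    group_agg_df.show(truncate=False)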
    
  • 2020-12-16 16:29
    from pyspark.sql import functions as F
    
    df.groupby("user").agg(F.collect_list("word"))
    
  • 2020-12-16 16:32

    You have a native aggregate function for that: collect_set (documented in pyspark.sql.functions).

    Then, you could use:

    from pyspark.sql import functions as F
    df.groupby("user").agg(F.collect_set("word"))
    