How to do opposite of explode in PySpark?

猫巷女王i 2020-12-16 15:46

Let's say I have a DataFrame with a column for users and another column for words they've written:

    +----+-----+
    |user| word|
    +----+-----+
    | Bob|hello|
    | Bob|world|
    |Mary| Have|
    |Mary|    a|
    |Mary| nice|
    |Mary|  day|
    +----+-----+

I would like to collect each user's words into an array, i.e. do the opposite of explode:

    +----+--------------------+
    |user|               words|
    +----+--------------------+
    | Bob|      [hello, world]|
    |Mary|[Have, a, nice, day]|
    +----+--------------------+
5 Answers
  •  暖寄归人
    2020-12-16 16:24

    As of the Spark 2.3 release we have Pandas UDFs (a.k.a. vectorized UDFs). The function below accomplishes the OP's task. A benefit of this approach is that the order of the elements is guaranteed to be preserved, which is essential in many cases, such as time-series analysis.

    import pandas as pd
    import findspark
    
    findspark.init()
    import pyspark
    from pyspark.sql import SparkSession, Row
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.types import StructField, ArrayType
    
    spark = SparkSession.builder.appName('test_collect_array_grouped').getOrCreate()
    
    def collect_array_grouped(df, groupbyCols, aggregateCol, outputCol):
        """
        Aggregate function: returns a new :class:`DataFrame` such that for a given column, aggregateCol,
        in a DataFrame, df, collect into an array the elements for each grouping defined by the groupbyCols list.
        The new DataFrame will have, for each row, the grouping columns and an array of the grouped
        values from aggregateCol in the outputCol.
    
        :param groupbyCols: list of columns to group by.
                Each element should be a column name (string) or an expression (:class:`Column`).
        :param aggregateCol: the column name of the column of values to aggregate into an array
                for each grouping.
        :param outputCol: the column name of the column to output the aggregated array to.
        """
        groupbyCols = [] if groupbyCols is None else groupbyCols
        df = df.select(groupbyCols + [aggregateCol])
        schema = df.select(groupbyCols).schema
        aggSchema = df.select(aggregateCol).schema
        arrayField = StructField(name=outputCol, dataType=ArrayType(aggSchema[0].dataType, False))
        schema = schema.add(arrayField)
        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def _get_array(pd_df):
            vals = pd_df[groupbyCols].iloc[0].tolist()
            vals.append(pd_df[aggregateCol].values)
            return pd.DataFrame([vals])
        return df.groupby(groupbyCols).apply(_get_array)
    
    rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                          Row(user='Bob', word='world'),
                                          Row(user='Mary', word='Have'),
                                          Row(user='Mary', word='a'),
                                          Row(user='Mary', word='nice'),
                                          Row(user='Mary', word='day')])
    df = spark.createDataFrame(rdd)
    
    collect_array_grouped(df, ['user'], 'word', 'users_words').show()
    
    +----+--------------------+
    |user|         users_words|
    +----+--------------------+
    |Mary|[Have, a, nice, day]|
    | Bob|      [hello, world]|
    +----+--------------------+
    
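    For many cases, the built-in aggregate function `collect_list` (or `collect_set`, which also de-duplicates) from `pyspark.sql.functions` is a simpler way to reverse `explode`. A minimal sketch on the same data — note that, unlike the grouped-map Pandas UDF above, the element order inside the array is not formally guaranteed after the groupBy shuffle:

    ```python
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName('collect_list_example').getOrCreate()

    df = spark.createDataFrame(
        [('Bob', 'hello'), ('Bob', 'world'),
         ('Mary', 'Have'), ('Mary', 'a'),
         ('Mary', 'nice'), ('Mary', 'day')],
        ['user', 'word'],
    )

    # collect_list gathers each group's values into an array (the inverse of
    # explode); no custom UDF or schema construction is needed.
    grouped = df.groupBy('user').agg(F.collect_list('word').alias('words'))
    grouped.show()
    ```

    Exploding `words` back, e.g. `grouped.select('user', F.explode('words').alias('word'))`, recovers the original rows, possibly in a different order.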
