Add column sum as new column in PySpark dataframe

粉色の甜心 2020-12-02 22:43

I'm using PySpark and I have a Spark dataframe with a bunch of numeric columns. I want to add a column that is the sum of all the other columns.

Suppose my dataframe ...

8 Answers
  • 2020-12-02 23:24

    Summing multiple columns from a list into one column

    PySpark's sum function (pyspark.sql.functions.sum) aggregates values within a single column, so it does not add columns together row-wise. Row-wise addition can be achieved with the expr function:

    from pyspark.sql.functions import expr
    
    cols_list = ['a', 'b', 'c']
    
    # Creating an addition expression using `join`
    expression = '+'.join(cols_list)
    
    df = df.withColumn('sum_cols', expr(expression))
    

    This gives us the desired sum of columns.
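
    A minimal end-to-end sketch of the same idea, assuming a SparkSession named spark and the example columns a, b and c used elsewhere on this page:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import expr

    spark = SparkSession.builder.getOrCreate()

    # small example frame with the columns from the question
    df = spark.createDataFrame([(1, 2, 3), (8, 5, 6)], ['a', 'b', 'c'])

    cols_list = ['a', 'b', 'c']
    expression = '+'.join(cols_list)   # builds the string "a+b+c"

    # each row gains a sum_cols column equal to a + b + c
    df.withColumn('sum_cols', expr(expression)).show()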

  • 2020-12-02 23:25

    This was not obvious to me; I could not find a row-wise sum of columns in the Spark DataFrames API.

    Version 2

    This can be done in a fairly simple way:

    newdf = df.withColumn('total', sum(df[col] for col in df.columns))
    

    df.columns is supplied by pyspark as a list of strings giving all of the column names in the Spark Dataframe. For a different sum, you can supply any other list of column names instead.
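
    For example, to sum just a subset of the columns (the subset and the ab_total name here are hypothetical), pass that list instead of df.columns:

    # sum only the named columns rather than every column in the frame
    subset = ['a', 'b']
    newdf = df.withColumn('ab_total', sum(df[col] for col in subset))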

    I did not try this as my first solution because I wasn't certain how it would behave. But it works.

    Version 1

    This is overly complicated, but works as well.

    You can do this:

    1. use df.columns to get a list of the names of the columns
    2. use that names list to make a list of the columns
    3. pass that list to something that will invoke the column's overloaded add function in a fold-type functional manner

    With Python's reduce, some knowledge of how operator overloading works, and the PySpark Column code, that becomes:

    from functools import reduce   # reduce is a builtin on Python 2, but must be imported on Python 3

    def column_add(a, b):
        return a.__add__(b)

    newdf = df.withColumn('total_col',
                          reduce(column_add, (df[col] for col in df.columns)))
    

    Note that this is Python's reduce, not Spark's RDD reduce, and that the second argument to reduce needs the surrounding parentheses because it is a generator expression.

    Tested, Works!

    $ pyspark
    >>> df = sc.parallelize([{'a': 1, 'b':2, 'c':3}, {'a':8, 'b':5, 'c':6}, {'a':3, 'b':1, 'c':0}]).toDF().cache()
    >>> df
    DataFrame[a: bigint, b: bigint, c: bigint]
    >>> df.columns
    ['a', 'b', 'c']
    >>> def column_add(a,b):
    ...     return a.__add__(b)
    ...
    >>> df.withColumn('total', reduce(column_add, ( df[col] for col in df.columns ) )).collect()
    [Row(a=1, b=2, c=3, total=6), Row(a=8, b=5, c=6, total=19), Row(a=3, b=1, c=0, total=4)]
    
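    As a variant, operator.add can stand in for the hand-written column_add helper, since it simply applies + to the Column objects. A sketch of that, assuming the same df:

    from functools import reduce
    from operator import add

    # operator.add(a, b) evaluates a + b, which calls Column.__add__
    newdf = df.withColumn('total_col', reduce(add, (df[col] for col in df.columns)))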