Sum of array elements depending on value condition pyspark

眼角桃花 2020-12-17 07:44

I have a pyspark dataframe:

id   |   column
------------------------------
1    |  [0.2, 2, 3, 4, 3, 0.5]
------------------------------
2    |  [7, 0.3, 0.3, 8, 2]
------------------------------

How can I create three new columns that contain, for each row, the sum of the array elements that are less than 2, greater than 2, and equal to 2?

3 Answers
  • 2020-12-17 08:01

    Here's a way you can try:

    import pyspark.sql.functions as F
    from pyspark.sql import Row
    
    # use map to split each array by the value condition, then sum each group
    s = (df
         .select('column')
         .rdd
         .map(lambda x: [[i for i in x.column if i < 2], 
                         [i for i in x.column if i > 2], 
                         [i for i in x.column if i == 2]])
         .map(lambda x: [Row(round(sum(i), 2)) for i in x])
         .toDF(['col<2', 'col>2', 'col=2']))
    
    # create a dummy id so we can join both data frames
    df = df.withColumn('mid', F.monotonically_increasing_id())
    s = s.withColumn('mid', F.monotonically_increasing_id())
    
    # simple join on the dummy id, then drop it
    df = df.join(s, on='mid').drop('mid')
    df.show()
    
    +---+--------------------+-----+------+-----+
    | id|              column|col<2| col>2|col=2|
    +---+--------------------+-----+------+-----+
    |  0|[0.2, 2.0, 3.0, 4...|[0.7]|[10.0]|[2.0]|
    |  1|[7.0, 0.3, 0.3, 8...|[0.6]|[15.0]|[2.0]|
    +---+--------------------+-----+------+-----+
    
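    A note on this approach: monotonically_increasing_id is not guaranteed to generate matching ids across two different DataFrames (the values depend on partitioning), so the join can in principle misalign rows. A more robust variant zips each side with a consecutive positional index via the RDD API instead of the two withColumn calls above; a minimal sketch, where with_row_index is a hypothetical helper not part of the original answer:

    from pyspark.sql import Row

    def with_row_index(sdf, name='mid'):
        # zipWithIndex assigns consecutive 0-based positions in RDD order
        return (sdf.rdd
                   .zipWithIndex()
                   .map(lambda pair: Row(**pair[0].asDict(), **{name: pair[1]}))
                   .toDF())

    df = with_row_index(df).join(with_row_index(s), on='mid').drop('mid')
    df.show()
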
  • 2020-12-17 08:05

    For Spark 2.4+, you can use the aggregate function and do the calculation in one step:

    from pyspark.sql.functions import expr
    
    # I adjusted the 2nd array-item in id=1 from 2.0 to 2.1 so there is no `2.0` when id=1
    df = spark.createDataFrame([(1,[0.2, 2.1, 3., 4., 3., 0.5]),(2,[7., 0.3, 0.3, 8., 2.,])],['id','column'])
    
    df.withColumn('data', expr("""
    
        aggregate(
          /* ArrayType argument */
          column,
          /* zero: set empty array to initialize acc */
          array(),
          /* merge: iterate through `column` and reduce based on the values of y and the array indices of acc */
          (acc, y) ->
            CASE
              WHEN y < 2.0 THEN array(IFNULL(acc[0],0) + y, acc[1], acc[2])
              WHEN y > 2.0 THEN array(acc[0], IFNULL(acc[1],0) + y, acc[2])
                           ELSE array(acc[0], acc[1], IFNULL(acc[2],0) + y)
            END,
          /* finish: to convert the array into a named_struct */
          acc -> (acc[0] as `column<2`, acc[1] as `column>2`, acc[2] as `column=2`)
        )
    
    """)).selectExpr('id', 'data.*').show()
    #+---+--------+--------+--------+
    #| id|column<2|column>2|column=2|
    #+---+--------+--------+--------+
    #|  1|     0.7|    12.1|    null|
    #|  2|     0.6|    15.0|     2.0|
    #+---+--------+--------+--------+
    
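    If you are on Spark 3.1+, the same aggregation can also be written with the Python API via pyspark.sql.functions.aggregate; a minimal sketch that mirrors the expression above but starts the accumulator from zeros, so empty buckets come out as 0.0 instead of null:

    from pyspark.sql import functions as F

    # accumulator layout: [sum of items < 2, sum of items > 2, sum of items = 2]
    zero = F.array(F.lit(0.0), F.lit(0.0), F.lit(0.0))

    def merge(acc, y):
        return (F.when(y < 2.0, F.array(acc[0] + y, acc[1], acc[2]))
                 .when(y > 2.0, F.array(acc[0], acc[1] + y, acc[2]))
                 .otherwise(F.array(acc[0], acc[1], acc[2] + y)))

    def finish(acc):
        return F.struct(acc[0].alias('column<2'),
                        acc[1].alias('column>2'),
                        acc[2].alias('column=2'))

    df.withColumn('data', F.aggregate('column', zero, merge, finish)) \
      .selectExpr('id', 'data.*').show()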

    Before Spark 2.4, functional support for ArrayType is limited; you can do it with explode and then groupby + pivot:

    from pyspark.sql.functions import sum as fsum, expr
    
    df.selectExpr('id', 'explode_outer(column) as item') \
      .withColumn('g', expr('if(item < 2, "column<2", if(item > 2, "column>2", "column=2"))')) \
      .groupby('id') \
      .pivot('g', ["column<2", "column>2", "column=2"]) \
      .agg(fsum('item')) \
      .show()
    #+---+--------+--------+--------+                                                
    #| id|column<2|column>2|column=2|
    #+---+--------+--------+--------+
    #|  1|     0.7|    12.1|    null|
    #|  2|     0.6|    15.0|     2.0|
    #+---+--------+--------+--------+
    

    In case explode is slow (i.e. SPARK-21657, which affected versions before Spark 2.3), use a UDF:

    from pyspark.sql.functions import udf
    from pyspark.sql.types import StructType, StructField, DoubleType
    
    schema = StructType([
        StructField("column>2", DoubleType()), 
        StructField("column<2", DoubleType()),
        StructField("column=2", DoubleType())
    ])
    
    def split_data(arr):
        # bucket each value by the condition and keep a running sum per bucket
        d = {}
        if arr is None:
            arr = []
        for y in arr:
            if y > 2:
                d['column>2'] = d.get('column>2', 0) + y
            elif y < 2:
                d['column<2'] = d.get('column<2', 0) + y
            else:
                d['column=2'] = d.get('column=2', 0) + y
        return d
    
    udf_split_data = udf(split_data, schema)
    
    df.withColumn('data', udf_split_data('column')).selectExpr('id', 'data.*').show()
    
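    Applied to the adjusted dataframe above, the UDF version yields the same result as the aggregate and pivot versions: 0.7 / 0.6 for column<2, 12.1 / 15.0 for column>2, and a null column=2 for id=1 since that row has no exact 2.0.
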
  • 2020-12-17 08:10

    For Spark 2.4+, you can use the aggregate and filter higher-order functions like this:

    df.withColumn("column<2", expr("aggregate(filter(column, x -> x < 2), 0D, (x, acc) -> acc + x)")) \
      .withColumn("column>2", expr("aggregate(filter(column, x -> x > 2), 0D, (x, acc) -> acc + x)")) \
      .withColumn("column=2", expr("aggregate(filter(column, x -> x == 2), 0D, (x, acc) -> acc + x)")) \
      .show(truncate=False)
    

    Gives:

    +---+------------------------------+--------+--------+--------+
    |id |column                        |column<2|column>2|column=2|
    +---+------------------------------+--------+--------+--------+
    |1  |[0.2, 2.0, 3.0, 4.0, 3.0, 0.5]|0.7     |10.0    |2.0     |
    |2  |[7.0, 0.3, 0.3, 8.0, 2.0]     |0.6     |15.0    |2.0     |
    +---+------------------------------+--------+--------+--------+
    
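    On Spark 3.1+ the same logic is also available through the Python API with pyspark.sql.functions.filter and pyspark.sql.functions.aggregate; a minimal sketch, where cond_sum is just an illustrative helper:

    from pyspark.sql import functions as F

    def cond_sum(cond, label):
        # sum the array items that satisfy `cond`, starting from a double zero
        return F.aggregate(F.filter('column', cond), F.lit(0.0),
                           lambda acc, x: acc + x).alias(label)

    df.select(
        'id',
        'column',
        cond_sum(lambda x: x < 2, 'column<2'),
        cond_sum(lambda x: x > 2, 'column>2'),
        cond_sum(lambda x: x == 2, 'column=2'),
    ).show(truncate=False)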