Comparing columns in PySpark

Happy的楠姐 2020-12-01 18:17

I am working on a PySpark DataFrame with n columns. I have a set of m columns (m < n) and my task is to choose the column with the max values in it.

For example:
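
For instance, with three numeric columns a, b and c (a hypothetical illustration using the same columns as the answer below): given the rows (1, 2, 3), (2, 1, 2) and (3, 4, 5), column c holds the largest value in every row (tying with a in the second row), so c is the column to pick.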

5 Answers
  •  失恋的感觉
    2020-12-01 18:50

    You can use `reduce` with SQL expressions over a list of columns:

    from functools import reduce
    from pyspark.sql.functions import col, when
    
    def row_max(*cols):
        # Fold the columns pairwise: keep x when x > y, otherwise take y.
        return reduce(
            lambda x, y: when(x > y, x).otherwise(y),
            [col(c) if isinstance(c, str) else c for c in cols]
        )
    
    df = (sc.parallelize([(1, 2, 3), (2, 1, 2), (3, 4, 5)])
        .toDF(["a", "b", "c"]))
    
    df.select(row_max("a", "b", "c").alias("max"))
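
    Since the helper accepts both column names and Column expressions, you can mix the two (a small usage sketch; for the sample rows this gives 4, 2 and 8):

    # names are wrapped with col(), Column expressions are passed through unchanged
    df.select(row_max("a", col("b") * 2, "c").alias("max")).show()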
    

    Spark 1.5+ also provides `least` and `greatest`:

    from pyspark.sql.functions import greatest
    
    df.select(greatest("a", "b", "c"))
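
    `greatest` accepts column names or Column expressions and skips null values (it returns null only when every input is null), so it behaves slightly differently from the `when`/`otherwise` chain above when nulls are present:

    # for the sample data this also gives 3, 2 and 5
    df.select(greatest("a", "b", "c").alias("max")).show()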
    

    If you want to keep the name of the max column, you can use structs:

    from pyspark.sql.functions import struct, lit
    
    def row_max_with_name(*cols):
        # Pair each value with its column name so the name survives the comparison.
        cols_ = [struct(col(c).alias("value"), lit(c).alias("col")) for c in cols]
        # Structs compare field by field, so the struct with the largest value wins.
        return greatest(*cols_).alias("greatest({0})".format(",".join(cols)))
    
    maxs = df.select(row_max_with_name("a", "b", "c").alias("maxs"))
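
    For the sample data column c holds or ties for the largest value in every row (ties are broken by the column-name field of the struct), so each struct in `maxs` should pair the row maximum with the name "c":

    # the struct keeps both the winning value and the column it came from
    maxs.select(col("maxs")["value"], col("maxs")["col"]).show()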
    

    And finally you can use the above to find and select the "top" column:

    from pyspark.sql.functions import max
    
    # Count how often each column provided the row maximum, then keep the most frequent one.
    ((_, c), ) = (maxs
        .groupBy(col("maxs")["col"].alias("col"))
        .count()
        .agg(max(struct(col("count"), col("col"))))
        .first())
    
    df.select(c)
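
    `first()` returns a single Row whose only field is the (count, col) struct, which the nested unpacking binds to `_` and `c`; for the sample data every row maximum comes from c, so `c` resolves to "c" and the final select returns that column.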
    
