Scala/Spark dataframes: find the column name corresponding to the max

栀梦 · 2020-12-10 17:40

In Scala/Spark, given a dataframe:

val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")

I would like to add a column maxCol holding the name of the column with the maximum value for each row.

1 Answer
  • 2020-12-10 18:31

    With a small trick you can use the greatest function. Required imports:

    import org.apache.spark.sql.functions.{col, greatest, lit, struct}
    

    First let's create a list of structs, where the first element is the value and the second is the column name:

    // Skip the "id" column and pair each value with its column name
    val structs = dfIn.columns.tail.map(
      c => struct(col(c).as("v"), lit(c).as("k"))
    )
    

    A structure like this can be passed to greatest as follows:

    dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))
    
    +---+---+---+---+------+
    | id| c0| c1| c2|maxCol|
    +---+---+---+---+------+
    | r0|  0|  2|  3|    c2|
    | r1|  1|  0|  0|    c0|
    | r2|  0|  2|  2|    c2|
    +---+---+---+---+------+
    

    Please note that in case of ties it will take the element which occurs later in the sequence, because the structs are compared lexicographically and (x, "c2") > (x, "c1"). If for some reason this is not acceptable, you can explicitly reduce with when:

    import org.apache.spark.sql.functions.when
    
    // Pairwise reduce: keep the struct with the larger value; on ties keep the earlier column
    val max_col = structs.reduce(
      (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
    ).getItem("k")
    
    dfIn.withColumn("maxCol", max_col)
    
    +---+---+---+---+------+
    | id| c0| c1| c2|maxCol|
    +---+---+---+---+------+
    | r0|  0|  2|  3|    c2|
    | r1|  1|  0|  0|    c0|
    | r2|  0|  2|  2|    c1|
    +---+---+---+---+------+
    

    In case of nullable columns you have to adjust this, for example by coalescing the values to -Inf.
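
    A minimal sketch of that adjustment (the name structsNullSafe is my own, not from the answer): coalesce each value with negative infinity so that null never wins the comparison. This casts the values to double.

    import org.apache.spark.sql.functions.{coalesce, col, greatest, lit, struct}

    // Replace nulls with -Inf so they lose every comparison (values are cast to double)
    val structsNullSafe = dfIn.columns.tail.map(
      c => struct(coalesce(col(c).cast("double"), lit(Double.NegativeInfinity)).as("v"), lit(c).as("k"))
    )

    dfIn.withColumn("maxCol", greatest(structsNullSafe: _*).getItem("k"))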
