How to get other columns when using Spark DataFrame groupby?


When I use DataFrame groupBy like this:

df.groupBy(df("age")).agg(Map("id" -> "count"))

I only get a DataFrame with the columns "age" and "count(id)", but df has many other columns, such as "name". How can I keep them in the result?
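For illustration, a minimal sketch of that behaviour (the sample data and the "name" column are assumptions for the example): only the grouping key and the aggregate survive the groupBy/agg.

    // assuming an active SparkSession named `spark`
    import spark.implicits._

    val df = Seq((1, "Moe", 18), (2, "Larry", 15)).toDF("id", "name", "age")

    val grouped = df.groupBy(df("age")).agg(Map("id" -> "count"))
    grouped.printSchema()
    // root
    //  |-- age: integer
    //  |-- count(id): long
    // "name" (and every other column of df) is no longer available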

7 Answers
  •  南方客
     2020-11-29 22:31

    You need to remember that aggregate functions reduce the rows of a group, so you have to specify how that reduction should pick the row whose values you want to keep. If you want to retain all rows of a group (warning: this can cause explosions or skewed partitions), you can collect them into a list with collect_list. You can then use a UDF (user defined function) to reduce them by your criterion (money, in my example), and extract columns from the single reduced row with another UDF. For the purpose of this answer I assume you wish to retain the name of the person who has the most money.

    import org.apache.spark.sql._
    import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.StringType

    import scala.collection.mutable

    object TestJob3 {

      def main(args: Array[String]): Unit = {

        val sparkSession = SparkSession
          .builder()
          .appName(this.getClass.getName.replace("$", ""))
          .master("local")
          .getOrCreate()

        import sparkSession.implicits._

        val rawDf = Seq(
          (1, "Moe",     "Slap",    2.0, 18),
          (2, "Larry",   "Spank",   3.0, 15),
          (3, "Curly",   "Twist",   5.0, 15),
          (4, "Laurel",  "Whimper", 3.0, 9),
          (5, "Hardy",   "Laugh",   6.0, 18),
          (6, "Charley", "Ignore",  5.0, 5)
        ).toDF("id", "name", "requisite", "money", "age")

        rawDf.show(false)
        rawDf.printSchema

        // the reducing UDF returns a whole row, so it needs the schema of rawDf
        val rawSchema = rawDf.schema

        val fUdf = udf(reduceByMoney, rawSchema)

        val nameUdf = udf(extractName, StringType)

        val aggDf = rawDf
          .groupBy("age")
          .agg(
            count(struct("*")).as("count"),
            max(col("money")),
            // collect the complete rows of each group into an array column
            collect_list(struct("*")).as("horizontal")
          )
          // reduce each collected group to the single row with the highest money
          .withColumn("short", fUdf($"horizontal"))
          // pull the name out of that reduced row
          .withColumn("name", nameUdf($"short"))
          .drop("horizontal")

        aggDf.printSchema

        aggDf.show(false)
      }

      // keeps, out of a collected group of rows, the one with the largest money value
      def reduceByMoney = (x: Any) => {

        val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

        d.reduce((r1, r2) => {
          val money1 = r1.getAs[Double]("money")
          val money2 = r2.getAs[Double]("money")

          if (money1 >= money2) r1 else r2
        })
      }

      // extracts the "name" field from the reduced row
      def extractName = (x: Any) => {
        val d = x.asInstanceOf[GenericRowWithSchema]
        d.getAs[String]("name")
      }
    }

    Here is the output:

    +---+-----+----------+----------------------------+-------+
    |age|count|max(money)|short                       |name   |
    +---+-----+----------+----------------------------+-------+
    |5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
    |15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
    |9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
    |18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
    +---+-----+----------+----------------------------+-------+
    
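    If collecting every row of a group is too heavy, a lighter alternative (a sketch not taken from the answer above, reusing the same rawDf) is to rank the rows of each group with a window function and keep only the top row, so all original columns are preserved without collect_list:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // rank the rows of each age group by money, highest first
    val byAge = Window.partitionBy("age").orderBy(col("money").desc)

    val richestPerAge = rawDf
      .withColumn("rn", row_number().over(byAge))
      .filter(col("rn") === 1) // keep only the row with the most money per age
      .drop("rn")

    richestPerAge.show(false)
    // all original columns (id, name, requisite, money, age) are still present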
