How to get other columns when using Spark DataFrame groupby?

Asked 2020-11-29 22:04

when I use DataFrame groupby like this:

df.groupBy(df(\"age\")).agg(Map(\"id\"->\"count\"))

I will only get a DataFrame with the columns "age" and "count(id)"; the other columns of df (such as "name") are dropped. How can I keep them in the result?

7 Answers
  • 2020-11-29 22:33

    Aggregate functions reduce the values of the rows within each group to a single value per aggregated column. If you want to retain the other row values, you need reduction logic that specifies which row each value comes from, for example: keep all the values of the row with the maximum money within each group. To this end you can use a UDAF (user-defined aggregate function) to reduce the rows within the group.

    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    
    
    object AggregateKeepingRowJob {
    
      def main (args: Array[String]): Unit = {
    
        val sparkSession = SparkSession
          .builder()
          .appName(this.getClass.getName.replace("$", ""))
          .master("local")
          .getOrCreate()
    
        val sc = sparkSession.sparkContext
        sc.setLogLevel("ERROR")
    
        import sparkSession.implicits._
    
        val rawDf = Seq(
          (1L, "Moe",  "Slap",  2.0, 18),
          (2L, "Larry",  "Spank",  3.0, 15),
          (3L, "Curly",  "Twist", 5.0, 15),
          (4L, "Laurel", "Whimper", 3.0, 15),
          (5L, "Hardy", "Laugh", 6.0, 15),
          (6L, "Charley",  "Ignore",   5.0, 5)
        ).toDF("id", "name", "requisite", "money", "age")
    
        rawDf.show(false)
        rawDf.printSchema
    
        // Instantiate the UDAF defined below.
        val maxAmtUdaf = new KeepRowWithMaxAmt
    
        val aggDf = rawDf
          .groupBy("age")
          .agg(
            count("id"),
            max(col("money")),
            maxAmtUdaf(
              col("id"),
              col("name"),
              col("requisite"),
              col("money"),
              col("age")).as("KeepRowWithMaxAmt")
          )
    
        aggDf.printSchema
        aggDf.show(false)
    
      }
    
    
    }
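
    The UDAF returns the retained row as a single struct column (aliased "KeepRowWithMaxAmt" above), so aggDf has the shape (age, count(id), max(money), struct). A small usage sketch, not part of the original answer: you can expand that struct back into ordinary top-level columns with .*:

    // Expand the struct produced by the UDAF into ordinary columns.
    aggDf.select($"KeepRowWithMaxAmt.*").show(false)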
    

    The UDAF:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    // Keeps the entire row that has the maximum "money" value within each group.
    class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {

      // The input fields for the aggregate function. They must match the columns
      // passed to the UDAF in the driver: id, name, requisite, money, age.
      override def inputSchema: StructType =
        StructType(
          StructField("id", LongType) ::
          StructField("name", StringType) ::
          StructField("requisite", StringType) ::
          StructField("money", DoubleType) ::
          StructField("age", IntegerType) :: Nil
        )

      // The internal fields kept while computing the aggregate:
      // the "best" row seen so far, with the same layout as the input.
      override def bufferSchema: StructType = inputSchema

      // The output type of the aggregation function: the retained row as a struct.
      override def dataType: DataType = inputSchema

      override def deterministic: Boolean = true

      // The initial value of the buffer. "money" starts at Double.MinValue
      // so that the first real row always replaces the initial buffer.
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = ""
        buffer(2) = ""
        buffer(3) = Double.MinValue
        buffer(4) = 0
      }

      // Update the buffer with an input row: keep the row with the larger "money".
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (input.getAs[Double](3) > buffer.getAs[Double](3)) {
          buffer(0) = input.getAs[Long](0)
          buffer(1) = input.getAs[String](1)
          buffer(2) = input.getAs[String](2)
          buffer(3) = input.getAs[Double](3)
          buffer(4) = input.getAs[Int](4)
        }
      }

      // Merge two buffers: keep whichever buffer holds the larger "money".
      // Copying buffer2 over buffer1 unconditionally would lose the maximum.
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        if (buffer2.getAs[Double](3) > buffer1.getAs[Double](3)) {
          buffer1(0) = buffer2.getAs[Long](0)
          buffer1(1) = buffer2.getAs[String](1)
          buffer1(2) = buffer2.getAs[String](2)
          buffer1(3) = buffer2.getAs[Double](3)
          buffer1(4) = buffer2.getAs[Int](4)
        }
      }

      // Output the final value: the retained row.
      override def evaluate(buffer: Row): Any = buffer
    }
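
    Note that UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of org.apache.spark.sql.expressions.Aggregator. And if you don't need a custom aggregate at all, a common alternative is a window function: rank the rows within each group and keep the top-ranked one, which preserves every original column. A minimal sketch, not part of the original answer, assuming the same rawDf and keeping the row with the highest money per age:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{desc, row_number}

    // Rank the rows within each age group by money, descending, and
    // keep only the top-ranked row; all original columns survive.
    // (The $-syntax requires import sparkSession.implicits._ as above.)
    val byMoney = Window.partitionBy("age").orderBy(desc("money"))

    val topPerAge = rawDf
      .withColumn("rn", row_number().over(byMoney))
      .filter($"rn" === 1)
      .drop("rn")

    topPerAge.show(false)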
    