Spark UDAF - using generics as input type?


For simplicity, let's assume you want to define a custom sum. You'll have to provide a TypeTag for the input type and use Scala reflection to define the schemas:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor

// Generic sum UDAF: the TypeTag lets us derive the Catalyst schema for T
// at construction time, while Numeric[T] provides zero and plus.
case class MySum[T : TypeTag](implicit n: Numeric[T])
    extends UserDefinedAggregateFunction {

  // Derive the Spark SQL DataType for T via Scala reflection
  val dt = schemaFor[T].dataType
  def inputSchema = new StructType().add("x", dt)
  def bufferSchema = new StructType().add("x", dt)

  def dataType = dt
  def deterministic = true

  // Start the aggregation buffer at the numeric zero for T
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, n.zero)

  // Fold a single input row into the buffer, skipping nulls
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, n.plus(buffer.getAs[T](0), input.getAs[T](0)))
  }

  // Combine two partial aggregation buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, n.plus(buffer1.getAs[T](0), buffer2.getAs[T](0)))
  }

  def evaluate(buffer: Row) = buffer.getAs[T](0)
}

With a function defined as above we can create instances handling specific types:

import spark.implicits._ // provides the $"..." column syntax (pre-imported in spark-shell)

val sumOfLong = MySum[Long]
spark.range(10).select(sumOfLong($"id")).show
+---------+
|mysum(id)|
+---------+
|       45|
+---------+
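
The same definition can be instantiated for other numeric types without further code. As a quick sketch (the DataFrame df and its DoubleType column "x" are assumed here for illustration, not part of the original example):

// Assumption: df is a DataFrame with a DoubleType column named "x"
val sumOfDouble = MySum[Double]
df.select(sumOfDouble($"x")).show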

Note:

To get the same flexibility as the built-in aggregate functions you'd have to define your own AggregateFunction, such as an ImperativeAggregate or DeclarativeAggregate. It is possible, but these are internal APIs.
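
If you are on Spark 2.x or later, a public alternative is the typed Aggregator API. The following is only a minimal sketch under that assumption (it is not part of the original answer) and requires an Encoder for the buffer/result type:

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator

// Minimal sketch: a generic sum expressed as a typed Aggregator (public API).
// The Encoder[T] is assumed to be in scope (e.g. via spark.implicits._).
class TypedSum[T](implicit n: Numeric[T], enc: Encoder[T])
    extends Aggregator[T, T, T] {
  def zero: T = n.zero                                 // initial buffer value
  def reduce(acc: T, x: T): T = n.plus(acc, x)         // fold one input into the buffer
  def merge(acc1: T, acc2: T): T = n.plus(acc1, acc2)  // combine partial buffers
  def finish(acc: T): T = acc                          // final result
  def bufferEncoder: Encoder[T] = enc
  def outputEncoder: Encoder[T] = enc
}

// Usage sketch on a typed Dataset:
// val sumLong = new TypedSum[Long].toColumn
// spark.range(10).as[Long].select(sumLong).show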
