How to define a custom aggregation function to sum a column of Vectors?

闹比i 2020-11-27 04:23

I have a DataFrame with two columns, ID of type Int and Vec of type Vector (org.apache.spark.mllib.linalg.Vector). I'd like to group the rows by ID and sum the vectors in the Vec column element-wise. How can I define a custom aggregation function to do that?

2 Answers
  •  误落风尘
    2020-11-27 05:06

    Spark >= 3.0

    You can use Summarizer with sum

    import org.apache.spark.ml.stat.Summarizer
    
    df
      .groupBy($"id")
      .agg(Summarizer.sum($"vec").alias("vec"))
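
    For completeness, here is a minimal self-contained sketch (assuming Spark >= 3.0 and an active SparkSession named spark, e.g. in spark-shell so the $-syntax is in scope), using the same data as the Spark < 3.0 example below; row order in the output may vary:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.stat.Summarizer

    val df = spark.createDataFrame(Seq(
        (1, Vectors.dense(0, 0, 5)), (1, Vectors.dense(4, 0, 1)), (1, Vectors.dense(1, 2, 1)),
        (2, Vectors.dense(7, 5, 0)), (2, Vectors.dense(3, 3, 4)),
        (3, Vectors.dense(0, 8, 1)), (3, Vectors.dense(0, 0, 1)), (3, Vectors.dense(7, 7, 7)))
      ).toDF("id", "vec")

    df.groupBy($"id").agg(Summarizer.sum($"vec").alias("vec")).show

    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+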
    

    Spark < 3.0

    Personally I wouldn't bother with UDAFs. They are more verbose than necessary and not exactly fast (see Spark UDAF with ArrayType as bufferSchema performance issues). Instead I would simply use reduceByKey / foldByKey:

    import org.apache.spark.sql.Row
    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.ml.linalg.{Vector, Vectors}

    // assuming an active SparkSession named `spark`;
    // needed for toDF on the aggregated RDD and for the $-column syntax
    import spark.implicits._
    
    def dv(values: Double*): Vector = Vectors.dense(values.toArray)
    
    val df = spark.createDataFrame(Seq(
        (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
        (2, dv(7,5,0)), (2, dv(3,3,4)), 
        (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
      ).toDF("id", "vec")
    
    val aggregated = df
      .rdd
      .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
      .foldByKey(BDV.zeros[Double](3))(_ += _) // 3 = vector length; the breeze accumulator is updated in place
      .mapValues(v => Vectors.dense(v.toArray))
      .toDF("id", "vec")
    
    aggregated.show
    
    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
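
    As an aside, the question mentions the older org.apache.spark.mllib.linalg.Vector type, while the snippets here use org.apache.spark.ml.linalg. If the column actually holds mllib vectors, it can be converted up front, for example with MLUtils (a sketch, assuming Spark >= 2.0):

    import org.apache.spark.mllib.util.MLUtils

    // rewrites the listed vector columns from mllib to ml vectors,
    // leaving the rest of the schema unchanged
    val dfMl = MLUtils.convertVectorColumnsToML(df, "vec")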
    

    And, just for comparison, a "simple" UDAF. Required imports:

    import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
      UserDefinedAggregateFunction}
    import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
    import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
    import org.apache.spark.sql.Row
    import scala.collection.mutable.WrappedArray
    

    Class definition:

    // n is the (fixed) length of the vectors being summed
    class VectorSum (n: Int) extends UserDefinedAggregateFunction {
        def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
        def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
        def dataType = SQLDataTypes.VectorType
        def deterministic = true 
    
        def initialize(buffer: MutableAggregationBuffer) = {
          buffer.update(0, Array.fill(n)(0.0))
        }
    
        def update(buffer: MutableAggregationBuffer, input: Row) = {
          if (!input.isNullAt(0)) {
            val buff = buffer.getAs[WrappedArray[Double]](0) 
            // convert to sparse so only the active (non-zero) entries are iterated below
            val v = input.getAs[Vector](0).toSparse
            for (i <- v.indices) {
              buff(i) += v(i)
            }
            buffer.update(0, buff)
          }
        }
    
        def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
          val buff1 = buffer1.getAs[WrappedArray[Double]](0) 
          val buff2 = buffer2.getAs[WrappedArray[Double]](0) 
          for ((x, i) <- buff2.zipWithIndex) {
            buff1(i) += x
          }
          buffer1.update(0, buff1)
        }
    
        def evaluate(buffer: Row) =  Vectors.dense(
          buffer.getAs[Seq[Double]](0).toArray)
    } 
    

    And an example usage:

    df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show
    
    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
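
    If the aggregate is also needed from SQL, the UDAF can be registered by name (a sketch; the function and view names here are arbitrary, and note that UserDefinedAggregateFunction is deprecated since Spark 3.0 in favour of Aggregator):

    // register the UDAF under a SQL-callable name and expose df as a temp view
    spark.udf.register("vec_sum", new VectorSum(3))
    df.createOrReplaceTempView("vectors")

    spark.sql("SELECT id, vec_sum(vec) AS vec FROM vectors GROUP BY id").show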
    

    See also: How to find mean of grouped Vector columns in Spark SQL?.
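
    On Spark >= 2.3, the mean asked about in that linked question can be computed the same way via Summarizer, for example:

    df.groupBy($"id").agg(Summarizer.mean($"vec").alias("vec"))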
