How to define a custom aggregation function to sum a column of Vectors?

[亡魂溺海] submitted on 2019-11-26 05:29:31

Question


I have a DataFrame of two columns, ID of type Int and Vec of type Vector (org.apache.spark.mllib.linalg.Vector).

The DataFrame looks as follows:

ID,Vec
1,[0,0,5]
1,[4,0,1]
1,[1,2,1]
2,[7,5,0]
2,[3,3,4]
3,[0,8,1]
3,[0,0,1]
3,[7,7,7]
....

I would like to do a groupBy($"ID") and then aggregate the rows inside each group by summing the vectors.

The desired output of the above example would be:

ID,SumOfVectors
1,[5,2,7]
2,[10,8,4]
3,[7,15,9]
...
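
For reference, the intended semantics can be sketched on plain Scala collections, with no Spark involved. The object name VectorSumReference and the helpers add and sumById are hypothetical names chosen for this illustration, and vectors are modelled as Seq[Double] rather than Spark's Vector type:

```scala
object VectorSumReference {
  // Element-wise sum of two equally sized vectors held as plain sequences.
  def add(a: Seq[Double], b: Seq[Double]): Seq[Double] = {
    require(a.length == b.length, "vectors must have the same length")
    a.zip(b).map { case (x, y) => x + y }
  }

  // Group rows by ID, then sum the vectors inside each group.
  def sumById(rows: Seq[(Int, Seq[Double])]): Map[Int, Seq[Double]] =
    rows.groupBy(_._1).map { case (id, group) =>
      id -> group.map(_._2).reduce((a, b) => add(a, b))
    }

  // The sample rows from the table above.
  val rows: Seq[(Int, Seq[Double])] = Seq(
    (1, Seq(0.0, 0.0, 5.0)), (1, Seq(4.0, 0.0, 1.0)), (1, Seq(1.0, 2.0, 1.0)),
    (2, Seq(7.0, 5.0, 0.0)), (2, Seq(3.0, 3.0, 4.0)),
    (3, Seq(0.0, 8.0, 1.0)), (3, Seq(0.0, 0.0, 1.0)), (3, Seq(7.0, 7.0, 7.0))
  )
}
```

Calling VectorSumReference.sumById(VectorSumReference.rows) gives the sums listed in the desired output, and can serve as a small oracle when testing a distributed implementation.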

The available aggregation functions will not work; e.g. df.groupBy($"ID").agg(sum($"Vec")) leads to a ClassCastException.

How can I implement a custom aggregation function that lets me sum vectors or arrays, or apply any other custom operation?


Answer 1:


Personally I wouldn't bother with UDAFs. They are more than verbose and not exactly fast (see Spark UDAF with ArrayType as bufferSchema performance issues). Instead I would simply use reduceByKey / foldByKey:

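import org.apache.spark.sql.Row
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.{Vector, Vectors}
// Already in scope in spark-shell; needed elsewhere for rdd.toDF and $"..."
import spark.implicits._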

def dv(values: Double*): Vector = Vectors.dense(values.toArray)

val df = spark.createDataFrame(Seq(
    (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
    (2, dv(7,5,0)), (2, dv(3,3,4)), 
    (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
  ).toDF("id", "vec")

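val aggregated = df
  .rdd
  // Pair each id with a Breeze dense vector so vectors can be added in place
  .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
  // Start each key from a zero vector; 3 is the vector length in this example
  .foldByKey(BDV.zeros[Double](3))(_ += _)
  // Convert back to Spark vectors before rebuilding the DataFrame
  .mapValues(v => Vectors.dense(v.toArray))
  .toDF("id", "vec")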

aggregated.show

// +---+--------------+
// | id|           vec|
// +---+--------------+
// |  1| [5.0,2.0,7.0]|
// |  2|[10.0,8.0,4.0]|
// |  3|[7.0,15.0,9.0]|
// +---+--------------+
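
The reduceByKey variant mentioned above can be sketched as follows. This is a minimal sketch rather than a tuned implementation: it assumes a local SparkSession (in spark-shell, spark already exists), the application name and the helper addArrays are hypothetical, and it uses plain arrays instead of Breeze. Unlike foldByKey it needs no zero vector, so the vector length does not have to be known up front:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("vector-sum-sketch") // hypothetical name
  .getOrCreate()
import spark.implicits._

val df = Seq(
  (1, Vectors.dense(0.0, 0.0, 5.0)), (1, Vectors.dense(4.0, 0.0, 1.0)),
  (1, Vectors.dense(1.0, 2.0, 1.0)), (2, Vectors.dense(7.0, 5.0, 0.0)),
  (2, Vectors.dense(3.0, 3.0, 4.0)), (3, Vectors.dense(0.0, 8.0, 1.0)),
  (3, Vectors.dense(0.0, 0.0, 1.0)), (3, Vectors.dense(7.0, 7.0, 7.0))
).toDF("id", "vec")

// Hypothetical helper: element-wise sum that allocates a new array,
// so none of the input vectors is mutated.
val addArrays: (Array[Double], Array[Double]) => Array[Double] =
  (a, b) => a.zip(b).map { case (x, y) => x + y }

val summed = df.rdd
  .map { case Row(k: Int, v: Vector) => (k, v.toArray) }
  .reduceByKey(addArrays)
  .mapValues(arr => Vectors.dense(arr))
  .toDF("id", "vec")
```

Allocating a new array on every merge is simpler but slower than the in-place Breeze update. Note also that zip silently truncates to the shorter array, so this sketch assumes all vectors in a group have the same length.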

And just for comparison a "simple" UDAF. Required imports:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
  UserDefinedAggregateFunction}
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
import org.apache.spark.sql.Row
import scala.collection.mutable.WrappedArray

Class definition:

class VectorSum (n: Int) extends UserDefinedAggregateFunction {
    def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
    def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
    def dataType = SQLDataTypes.VectorType
    def deterministic = true 

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, Array.fill(n)(0.0))
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) {
        val buff = buffer.getAs[WrappedArray[Double]](0) 
        val v = input.getAs[Vector](0).toSparse
        for (i <- v.indices) {
          buff(i) += v(i)
        }
        buffer.update(0, buff)
      }
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val buff1 = buffer1.getAs[WrappedArray[Double]](0) 
      val buff2 = buffer2.getAs[WrappedArray[Double]](0) 
      for ((x, i) <- buff2.zipWithIndex) {
        buff1(i) += x
      }
      buffer1.update(0, buff1)
    }

    def evaluate(buffer: Row) =  Vectors.dense(
      buffer.getAs[Seq[Double]](0).toArray)
} 

And an example usage:

df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show

// +---+--------------+
// | id|           vec|
// +---+--------------+
// |  1| [5.0,2.0,7.0]|
// |  2|[10.0,8.0,4.0]|
// |  3|[7.0,15.0,9.0]|
// +---+--------------+

See also: How to find mean of grouped Vector columns in Spark SQL?.




Answer 2:


I suggest the following (works on Spark 2.0.2 onward). It could probably be optimized, but it works nicely. The one thing you have to know in advance is the vector size, which you pass in when you create the UDAF instance.

import org.apache.spark.ml.linalg._
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class VectorAggregate(val numFeatures: Int)
    extends UserDefinedAggregateFunction {

  // Buffer type: vector index -> running sum at that index
  private type B = Map[Int, Double]

  // SQLDataTypes.VectorType is the public handle for the vector UDT
  def inputSchema: StructType =
    StructType(StructField("vec", SQLDataTypes.VectorType) :: Nil)

  def bufferSchema: StructType =
    StructType(StructField("agg", MapType(IntegerType, DoubleType)) :: Nil)

  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[Int, Double])

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val zero = buffer.getAs[B](0)
    input match {
      case Row(DenseVector(values)) =>
        buffer.update(0, values.zipWithIndex.foldLeft(zero) {
          case (acc, (v, i)) => acc.updated(i, v + acc.getOrElse(i, 0d))
        })
      case Row(SparseVector(_, indices, values)) =>
        buffer.update(0, values.zip(indices).foldLeft(zero) {
          case (acc, (v, i)) => acc.updated(i, v + acc.getOrElse(i, 0d))
        })
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val zero = buffer1.getAs[B](0)
    buffer1.update(0, buffer2.getAs[B](0).foldLeft(zero) {
      case (acc, (i, v)) => acc.updated(i, v + acc.getOrElse(i, 0d))
    })
  }

  def deterministic: Boolean = true

  def evaluate(buffer: Row): Any = {
    val Row(agg: B) = buffer
    // Vectors.sparse expects the indices in ascending order
    val indices = agg.keys.toArray.sorted
    Vectors.sparse(numFeatures, indices, indices.map(agg)).compressed
  }

  def dataType: DataType = SQLDataTypes.VectorType
}
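
The map-based buffer logic of this UDAF can be tried out without Spark. The following is a small sketch under that assumption; the object name SparseSumSketch and its methods are hypothetical stand-ins that mirror update, merge and evaluate above, with a sparse vector passed as parallel index and value arrays:

```scala
object SparseSumSketch {
  // Vector index -> running sum at that index.
  type Buffer = Map[Int, Double]

  // Add one sparse vector (parallel index/value arrays) into the buffer.
  def update(buffer: Buffer, indices: Array[Int], values: Array[Double]): Buffer =
    indices.zip(values).foldLeft(buffer) { case (acc, (i, v)) =>
      acc.updated(i, acc.getOrElse(i, 0.0) + v)
    }

  // Combine two partial buffers, as happens when partial aggregates are merged.
  def merge(left: Buffer, right: Buffer): Buffer =
    right.foldLeft(left) { case (acc, (i, v)) =>
      acc.updated(i, acc.getOrElse(i, 0.0) + v)
    }

  // Ascending indices with their sums, the shape Vectors.sparse takes.
  def finish(buffer: Buffer): (Array[Int], Array[Double]) = {
    val indices = buffer.keys.toArray.sorted
    (indices, indices.map(buffer))
  }
}
```

Only indices that actually occur end up in the map, which is why the vector size has to be supplied separately when the final sparse vector is built.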


Source: https://stackoverflow.com/questions/33899977/how-to-define-a-custom-aggregation-function-to-sum-a-column-of-vectors
