Returning Multiple Arrays from User-Defined Aggregate Function (UDAF) in Apache Spark SQL

Submitted by 点点圈 on 2019-12-30 03:13:10

Question


I am trying to create a user-defined aggregate function (UDAF) in Java using Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.

I am able to return a single array, but cannot figure out how to get the data in the correct format in the evaluate() method for returning multiple arrays.

The UDAF does work, since I can print out the arrays in the evaluate() method; I just can't figure out how to return those arrays to the calling code (shown below for reference).

UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col"))).as("processed_data");

I have included the whole custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.

Any help or advice would be greatly appreciated. Thank you.

public class CustomUDAF extends UserDefinedAggregateFunction {

    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));

        // Processing of data (omitted)

        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }

    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }

    @Override
    public StructType bufferSchema() {
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));

        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));

        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));

        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));

        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}

Update: Based on the answer by zero323, I was able to return two arrays using:

return new Tuple2<>(longArray, dataArray); // requires import scala.Tuple2;

Getting the data out of the result was a bit of a struggle, but it involved deconstructing the DataFrame into Java Lists and then building it back into a DataFrame.
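The fix in the update comes down to one rule: because dataType() declares a struct with two array fields, evaluate() has to return a single composite value whose fields appear in the same order (longArray, then dataArray), instead of returning just one of the lists. Below is a minimal plain-Java model of that lifecycle, written without any Spark dependency so it runs on its own. The names TwoArrayAggModel, Buffer and Result are hypothetical stand-ins for the aggregation buffer and the struct result; they are not Spark API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Plain-Java model (no Spark dependency) of the CustomUDAF lifecycle.
 * The buffer holds two parallel lists, mirroring bufferSchema(), and
 * evaluate() returns ONE composite value whose two fields line up with
 * the two fields declared in dataType(): longArray first, dataArray second.
 */
public class TwoArrayAggModel {

    /** Hypothetical stand-in for the struct<longArray, dataArray> result. */
    static final class Result {
        final long[] longArray;
        final double[] dataArray;

        Result(long[] longArray, double[] dataArray) {
            this.longArray = longArray;
            this.dataArray = dataArray;
        }
    }

    /** Hypothetical stand-in for the aggregation buffer. */
    static final class Buffer {
        final List<Long> longList = new ArrayList<>();
        final List<Double> dataList = new ArrayList<>();
    }

    // Mirrors update(): append one input row's values to each list.
    static void update(Buffer buffer, long l, double d) {
        buffer.longList.add(l);
        buffer.dataList.add(d);
    }

    // Mirrors merge(): concatenate the second partial buffer onto the first.
    static void merge(Buffer buffer1, Buffer buffer2) {
        buffer1.longList.addAll(buffer2.longList);
        buffer1.dataList.addAll(buffer2.dataList);
    }

    // Mirrors evaluate(): package BOTH arrays into a single return value
    // instead of returning only one of the lists.
    static Result evaluate(Buffer buffer) {
        long[] longs = buffer.longList.stream().mapToLong(Long::longValue).toArray();
        double[] data = buffer.dataList.stream().mapToDouble(Double::doubleValue).toArray();
        return new Result(longs, data);
    }

    public static void main(String[] args) {
        Buffer partition1 = new Buffer();
        update(partition1, 1L, 1.0);
        update(partition1, 2L, 2.0);

        Buffer partition2 = new Buffer();
        update(partition2, 3L, 3.0);

        merge(partition1, partition2);
        Result result = evaluate(partition1);

        System.out.println(Arrays.toString(result.longArray)); // [1, 2, 3]
        System.out.println(Arrays.toString(result.dataArray)); // [1.0, 2.0, 3.0]
    }
}
```

In the real CustomUDAF, the composite value would be the scala.Tuple2 shown in the update. As far as I can tell, a Row built with org.apache.spark.sql.RowFactory.create(longArray, dataArray) should also be accepted for a StructType result, but treat that as an assumption to check against your Spark version.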


Answer 1:


As far as I can tell, returning a tuple should be enough. In Scala:

// Assumes spark-shell, where sc and the SQL implicits (toDF, $"...") are already in scope
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)
  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))
  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))
  def deterministic = true 
  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}

val df = sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
df.select(DummyUDAF($"k")).show(1, false)

// +---------------------------------------------------+
// |(DummyUDAF$(k),mode=Complete,isDistinct=false)     |
// +---------------------------------------------------+
// |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
// +---------------------------------------------------+


Source: https://stackoverflow.com/questions/33939642/returning-multiple-arrays-from-user-defined-aggregate-function-udaf-in-apache
