How can I pass extra parameters to UDFs in Spark SQL?

Submitted by …衆ロ難τιáo~ on 2019-12-17 04:01:28

Question


I want to parse the date columns in a DataFrame. Each date column may need a different resolution (e.g. 2011/01/10 => 2011/01 if the resolution is set to "Month").

I wrote the following code:

def convertDataFrame(dataframe: DataFrame, schema : Array[FieldDataType], resolution: Array[DateResolutionType]) : DataFrame =
{
  import org.apache.spark.sql.functions._
  val convertDateFunc = udf{(x:String, resolution: DateResolutionType) => SparkDateTimeConverter.convertDate(x, resolution)}
  val convertDateTimeFunc = udf{(x:String, resolution: DateResolutionType) => SparkDateTimeConverter.convertDateTime(x, resolution)}

  val allColNames = dataframe.columns
  val allCols = allColNames.map(name => dataframe.col(name))

  val mappedCols =
  {
    for(i <- allCols.indices) yield
    {
      schema(i) match
      {
        case FieldDataType.Date => convertDateFunc(allCols(i), resolution(i))
        case FieldDataType.DateTime => convertDateTimeFunc(allCols(i), resolution(i))
        case _ => allCols(i)
      }
    }
  }

  dataframe.select(mappedCols:_*)

}

However, it doesn't work: it seems that only Columns can be passed to UDFs. I also wonder whether it would be very slow to convert the DataFrame to an RDD and apply the function to each row.

Does anyone know the correct solution? Thank you!


Answer 1:


Just use a little bit of currying:

def convertDateFunc(resolution: DateResolutionType) = udf((x:String) => 
  SparkDateTimeConverter.convertDate(x, resolution))

and use it as follows:

case FieldDataType.Date => convertDateFunc(resolution(i))(allCols(i))
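
To make the pattern concrete, here is a minimal self-contained sketch. The Resolution type and the convertDate body are hypothetical stand-ins for the question's DateResolutionType and SparkDateTimeConverter.convertDate, which are not shown in the thread; the sketch also assumes non-null "yyyy/MM/dd" strings and a local SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical stand-in for the question's DateResolutionType.
sealed trait Resolution
case object Year extends Resolution
case object Month extends Resolution
case object Day extends Resolution

object CurriedUdfSketch {
  // Hypothetical stand-in for SparkDateTimeConverter.convertDate:
  // truncates a non-null "yyyy/MM/dd" string to the requested resolution.
  def convertDate(s: String, r: Resolution): String = r match {
    case Year  => s.take(4)
    case Month => s.take(7)
    case Day   => s
  }

  // The resolution is captured by the closure, so the resulting UDF
  // takes a single Column argument.
  def convertDateFunc(r: Resolution): UserDefinedFunction =
    udf((s: String) => convertDate(s, r))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("curried-udf-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq("2011/01/10", "2012/12/25").toDF("d")
    // First argument list: the plain Scala parameter. Second: the Column.
    df.select(convertDateFunc(Month)(col("d")).as("d_month")).show()

    spark.stop()
  }
}
```

Because the extra parameter is an ordinary Scala value fixed when the UDF is built, it never has to be a Column. The value does need to be serializable, since Spark ships the closure to the executors.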

On a side note, you should take a look at sql.functions.trunc and sql.functions.date_format. These should do at least part of the job without using UDFs at all.
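
As a rough illustration of the UDF-free route, the sketch below parses a string column and reduces it to month resolution with the built-in functions. The column names and the "yyyy/MM/dd" input format are assumptions made for the example, and to_date with a format argument requires Spark 2.2 or later.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, date_format, to_date, trunc}

object BuiltinDateSketch {
  // Expects a string column named "raw" (hypothetical name) holding
  // dates formatted as yyyy/MM/dd.
  def withMonth(df: DataFrame): DataFrame =
    df.withColumn("d", to_date(col("raw"), "yyyy/MM/dd"))
      // date_format renders the date as a string at month resolution.
      .withColumn("month_string", date_format(col("d"), "yyyy/MM"))
      // trunc keeps a DateType value, set to the first day of the month.
      .withColumn("month_start", trunc(col("d"), "month"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("builtin-date-sketch")
      .getOrCreate()
    import spark.implicits._

    withMonth(Seq("2011/01/10").toDF("raw")).show()

    spark.stop()
  }
}
```

Built-in functions are visible to the Catalyst optimizer, whereas a UDF is opaque to it, so this route is usually worth trying first when the required resolutions map onto supported formats.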

Note:

In Spark 2.2 or later you can also pass the extra argument as a literal column with the typedLit function:

import org.apache.spark.sql.functions.typedLit

which supports a wider range of literals, such as Seq or Map.
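
A small sketch of what that can look like: a Map is turned into a literal column with typedLit and handed to a UDF as its second argument. The lookup table, the column names, and the lookup UDF are all invented for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, typedLit, udf}

object TypedLitSketch {
  // Hypothetical UDF: both the value and the lookup table arrive as Columns.
  val lookup = udf((code: String, names: scala.collection.Map[String, String]) =>
    names.getOrElse(code, "unknown"))

  // Expects a string column named "code" (hypothetical name).
  def withLabel(df: DataFrame): DataFrame = {
    val labels = Map("M" -> "Month", "Y" -> "Year")
    // typedLit builds a literal Column from the Scala Map.
    df.withColumn("label", lookup(col("code"), typedLit(labels)))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("typedlit-sketch")
      .getOrCreate()
    import spark.implicits._

    withLabel(Seq("M", "X").toDF("code")).show()

    spark.stop()
  }
}
```

For a large lookup table, capturing the Map in the closure (as in the currying approach) or using a broadcast variable may be preferable to embedding it as a literal in every row's expression.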




Answer 2:


You can create a literal Column to pass to a UDF using the lit(...) function defined in org.apache.spark.sql.functions.

For example:

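import org.apache.spark.sql.functions.{lit, udf}
import spark.implicits._ // for the $"..." syntax; assumes a SparkSession named spark

val takeRight = udf((s: String, i: Int) => s.takeRight(i))
df.select(takeRight($"stringCol", lit(1)))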


Source: https://stackoverflow.com/questions/35546576/how-can-i-pass-extra-parameters-to-udfs-in-spark-sql
