Pass an array as a UDF parameter in Spark SQL

Submitted anonymously (unverified) on 2019-12-03 02:30:02

Question:

I'm trying to transform a dataframe via a function that takes an array as a parameter. My code looks something like this:

def getCategory(categories: Array[String], input: String): String = {
  categories(input.toInt)
}

val myArray = Array("a", "b", "c")
val myCategories = udf(getCategory _)

val df = sqlContext.parquetFile("myfile.parquet")
val df1 = df.withColumn("newCategory", myCategories(lit(myArray), col("myInput")))

However, lit doesn't accept arrays, so this script errors. I then tried defining a new partially applied function and building the UDF from that:

val newFunc = getCategory(myArray, _: String)
val myCategories = udf(newFunc)
val df1 = df.withColumn("newCategory", myCategories(col("myInput")))

This doesn't work either: I get a NullPointerException, and it appears myArray is not being recognized. Any ideas on how to pass an array as a parameter to a function applied to a dataframe?

On a separate note, is there any explanation for why doing something as simple as applying a function to a dataframe is so complicated (define the function, redefine it as a UDF, and so on)?

Answer 1:

Most likely not the prettiest solution, but you can try something like this:

def getCategory(categories: Array[String]) = {
  udf((input: String) => categories(input.toInt))
}

df.withColumn("newCategory", getCategory(myArray)(col("myInput")))
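For example, with a tiny DataFrame this pattern behaves as expected (the column name myInput and the sample rows here are just assumptions for illustration):

import org.apache.spark.sql.functions.{col, udf}

val myArray = Array("a", "b", "c")

// hypothetical sample data: "myInput" holds string-encoded indices
val df = sqlContext.createDataFrame(Seq(Tuple1("0"), Tuple1("2"))).toDF("myInput")

df.withColumn("newCategory", getCategory(myArray)(col("myInput"))).show()
// +-------+-----------+
// |myInput|newCategory|
// +-------+-----------+
// |      0|          a|
// |      2|          c|
// +-------+-----------+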

You could also try an array of literals:

val getCategory = udf(
  // note: an ArrayType column arrives inside a UDF as Seq, not Array
  (input: String, categories: Seq[String]) => categories(input.toInt))

df.withColumn(
  "newCategory", getCategory($"myInput", array(myArray.map(lit(_)): _*)))

On a side note, using a Map instead of an Array is probably a better idea:

def mapCategory(categories: Map[String, String], default: String) = {
  udf((input: String) => categories.getOrElse(input, default))
}

val myMap = Map[String, String]("1" -> "a", "2" -> "b", "3" -> "c")

df.withColumn("newCategory", mapCategory(myMap, "foo")(col("myInput")))
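One advantage over the Array version: an out-of-range index makes the Array-based UDF throw an ArrayIndexOutOfBoundsException at runtime, whereas here any unmapped key simply falls back to the default. A quick check (again assuming the hypothetical string column myInput from above):

// keys present in myMap resolve normally, everything else gets the default
df.withColumn("newCategory", mapCategory(myMap, "foo")(col("myInput"))).show()
// myInput = "1" yields "a"; an unmapped value such as "9" yields "foo"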

Since Spark 1.5.0 you can also use an array function:

import org.apache.spark.sql.functions.array

val colArray = array(myArray.map(lit _): _*)
// note: as above, the UDF behind myCategories must take its
// array argument as Seq[String] for this to work
myCategories(colArray, col("myInput"))
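On a side note, if you are on Spark 2.2 or later (newer than this answer), typedLit builds the array literal in one step; a minimal sketch, keeping the Seq caveat from above in mind:

import org.apache.spark.sql.functions.typedLit

// typedLit understands Seq, Array and Map, so no per-element lit is needed
val colArray2 = typedLit(myArray.toSeq)
myCategories(colArray2, col("myInput"))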

See also Spark UDF with varargs


