How to return a “Tuple type” in a UDF in PySpark?

轻奢々 2020-12-24 04:34

All the data types in pyspark.sql.types are:

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    ...]

There is no tuple-like type in this list, so what return type should a UDF declare in order to return a tuple?

3 Answers
  •  一整个雨季
    2020-12-24 04:53

    Stackoverflow keeps directing me to this question, so I guess I'll add some info here.

    Returning simple types from UDF:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import *
    
    # The original answer relied on a pre-existing sqlContext; a SparkSession works the same way
    spark = SparkSession.builder.getOrCreate()
    
    def get_df():
      d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
      df = spark.createDataFrame(d, ['x', 'y'])
      return df
    
    df = get_df()
    df.show()
    
    # +---+---+
    # |  x|  y|
    # +---+---+
    # |0.0|0.0|
    # |0.0|3.0|
    # |1.0|6.0|
    # |1.0|9.0|
    # +---+---+
    
    func = udf(lambda x: str(x), StringType())
    df = df.withColumn('y_str', func('y'))
    
    func = udf(lambda x: int(x), IntegerType())
    df = df.withColumn('y_int', func('y'))
    
    df.show()
    
    # +---+---+-----+-----+
    # |  x|  y|y_str|y_int|
    # +---+---+-----+-----+
    # |0.0|0.0|  0.0|    0|
    # |0.0|3.0|  3.0|    3|
    # |1.0|6.0|  6.0|    6|
    # |1.0|9.0|  9.0|    9|
    # +---+---+-----+-----+
    
    df.printSchema()
    
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- y_str: string (nullable = true)
    #  |-- y_int: integer (nullable = true)
    
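    On newer Spark versions the same thing can be written with the udf decorator and a DDL type string instead of a DataType instance. A minimal sketch (assuming Spark 2.3+; y_as_str is just an illustrative name):

    from pyspark.sql.functions import udf
    
    # Decorator form of a UDF; the return type is given as the DDL string
    # "string" rather than StringType() (the string form needs Spark 2.3+)
    @udf("string")
    def y_as_str(y):
        return str(y)
    
    df = get_df().withColumn('y_str', y_as_str('y'))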

    When integers are not enough:

    df = get_df()
    
    func = udf(lambda x: [0]*int(x), ArrayType(IntegerType()))
    df = df.withColumn('list', func('y'))
    
    func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
               MapType(FloatType(), StringType()))
    df = df.withColumn('map', func('y'))
    
    df.show()
    # +---+---+--------------------+--------------------+
    # |  x|  y|                list|                 map|
    # +---+---+--------------------+--------------------+
    # |0.0|0.0|                  []|               Map()|
    # |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
    # |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
    # |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
    # +---+---+--------------------+--------------------+
    
    df.printSchema()
    # root
    #  |-- x: double (nullable = true)
    #  |-- y: double (nullable = true)
    #  |-- list: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- map: map (nullable = true)
    #  |    |-- key: float
    #  |    |-- value: string (valueContainsNull = true)
    
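    The array and map columns produced this way can be indexed directly with Column.getItem; a small sketch against the columns created above (rows with an empty list simply yield null):

    # First element of 'list', and the value stored under key 2.0 in 'map'
    df.select(F.col('list').getItem(0).alias('first_elem'),
              F.col('map').getItem(2.0).alias('val_at_2')).show()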

    Returning complex datatypes from UDF:

    df = get_df()
    df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
    df.show()
    
    # +---+----------+
    # |  x|       y[]|
    # +---+----------+
    # |0.0|[0.0, 3.0]|
    # |1.0|[9.0, 6.0]|
    # +---+----------+
    
    schema = StructType([
        StructField("min", FloatType(), True),
        StructField("size", IntegerType(), True),
        StructField("edges",  ArrayType(FloatType()), True),
        StructField("val_to_index",  MapType(FloatType(), IntegerType()), True)
        # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))
    
    ])
    
    def func(values):
      mn = min(values)
      size = len(values)
      lst = sorted(values)[::-1]
      val_to_index = {x: i for i, x in enumerate(values)}
      return (mn, size, lst, val_to_index)
    
    func = udf(func, schema)
    dff = df.select('*', func('y[]').alias('complex_type'))
    dff.show(10, False)
    
    # +---+----------+------------------------------------------------------+
    # |x  |y[]       |complex_type                                          |
    # +---+----------+------------------------------------------------------+
    # |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
    # |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
    # +---+----------+------------------------------------------------------+
    
    dff.printSchema()
    
    # root
    #  |-- x: double (nullable = true)
    #  |-- y[]: array (nullable = false)
    #  |    |-- element: double (containsNull = false)
    #  |-- complex_type: struct (nullable = true)
    #  |    |-- min: float (nullable = true)
    #  |    |-- size: integer (nullable = true)
    #  |    |-- edges: array (nullable = true)
    #  |    |    |-- element: float (containsNull = true)
    #  |    |-- val_to_index: map (nullable = true)
    #  |    |    |-- key: float
    #  |    |    |-- value: integer (valueContainsNull = true)
    
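    Because the UDF returns a struct, the individual "tuple" fields can afterwards be selected by name, which is usually the point of returning a tuple in the first place. A small sketch against the dff built above:

    # Pull individual fields out of the struct column
    dff.select('x', 'complex_type.min', 'complex_type.size').show()
    
    # 'complex_type.edges' and 'complex_type.val_to_index' are accessed the same way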

    Passing multiple arguments to a UDF:

    df = get_df()
    func = udf(lambda arr: arr[0] * arr[1], FloatType())
    df = df.withColumn('x*y', func(F.array('x', 'y')))
    df.show()
    
    # +---+---+---+
    # |  x|  y|x*y|
    # +---+---+---+
    # |0.0|0.0|0.0|
    # |0.0|3.0|0.0|
    # |1.0|6.0|6.0|
    # |1.0|9.0|9.0|
    # +---+---+---+
    
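    Wrapping the columns in F.array is not required, though: a UDF can also take several columns as separate arguments. A sketch equivalent to the example above (mul is just an illustrative name):

    # Two-argument UDF: each column arrives as its own Python argument
    mul = udf(lambda a, b: a * b, DoubleType())
    df = get_df().withColumn('x*y', mul('x', 'y'))
    df.show()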

    The code above is purely for demo purposes: all of the transformations shown are available as built-in Spark functions and would yield much better performance. As @zero323 notes in the comments above, UDFs should generally be avoided in PySpark; if you find yourself returning complex types from a UDF, consider whether the logic can be simplified. A rough native sketch of the last aggregation follows.
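
    For instance, the "complex_type" aggregation can be expressed with built-in functions only. A rough sketch (array_min needs Spark 2.4+; the val_to_index map is left out because it has no simple one-line native equivalent):

    df = get_df().groupBy('x').agg(F.collect_list('y').alias('y[]'))
    
    native = df.select(
        '*',
        F.struct(
            F.array_min('y[]').alias('min'),               # smallest value
            F.size('y[]').alias('size'),                   # number of elements
            F.sort_array('y[]', asc=False).alias('edges')  # descending sort
        ).alias('complex_type')
    )
    native.show(10, False)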
