Rename nested field in spark dataframe

2020-11-27 16:33

Having a dataframe df in Spark:

 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

How can I rename the nested field a, for example to a_renamed?
3 Answers
  • 2020-11-27 16:44

    Python

    It is not possible to modify a single nested field in place. You have to recreate the whole structure. In this particular case the simplest solution is to use cast.

    First a bunch of imports:

    from collections import namedtuple
    from pyspark.sql.functions import col
    from pyspark.sql.types import (
        ArrayType, LongType, StringType, StructField, StructType)
    

    and example data:

    Record = namedtuple("Record", ["a", "b", "c"])
    
    df = sc.parallelize([([Record("foo", 1, 3)], )]).toDF(["array_field"])
    

    Let's confirm that the schema is the same as in your case:

    df.printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
    

    You can define the new schema, for example, as a string:

    str_schema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"
    
    df.select(col("array_field").cast(str_schema)).printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a_renamed: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
    

    or a DataType:

    struct_schema = ArrayType(StructType([
        StructField("a_renamed", StringType()),
        StructField("b", LongType()),
        StructField("c", LongType())
    ]))
    
    df.select(col("array_field").cast(struct_schema)).printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a_renamed: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
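
    Either way, the cast only relabels the nested fields; the data itself is untouched. A quick sanity check (a minimal sketch, assuming the example data above):

    df.select(col("array_field").cast(struct_schema)).first()
    # expected along the lines of:
    # Row(array_field=[Row(a_renamed='foo', b=1, c=3)])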
    

    Scala

    The same techniques can be used in Scala:

    case class Record(a: String, b: Long, c: Long)
    
    val df = Seq(Tuple1(Seq(Record("foo", 1, 3)))).toDF("array_field")
    
    val strSchema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"
    
    df.select($"array_field".cast(strSchema))
    

    or

    import org.apache.spark.sql.types._
    
    val structSchema = ArrayType(StructType(Seq(
        StructField("a_renamed", StringType),
        StructField("b", LongType),
        StructField("c", LongType)
    )))
    
    df.select($"array_field".cast(structSchema))
    

    Possible improvements:

    If you use an expressive data manipulation or JSON processing library, it can be easier to dump the data type to a dict or a JSON string and take it from there, for example (Python / toolz):

    from toolz.curried import pipe, assoc_in, update_in, map
    from operator import attrgetter
    
    # Update name to "a_updated" if name is "a"
    rename_field = update_in(
        keys=["name"], func=lambda x: "a_updated" if x == "a" else x)
    
    updated_schema = pipe(
       #  Get schema of the field as a dict
       df.schema["array_field"].jsonValue(),
       # Update fields with rename
       update_in(
           keys=["type", "elementType", "fields"],
           func=lambda x: pipe(x, map(rename_field), list)),
       # Load schema from dict
       StructField.fromJson,
       # Get data type
       attrgetter("dataType"))
    
    df.select(col("array_field").cast(updated_schema)).printSchema()
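
    If you would rather not depend on toolz, the same dict round trip can be written in plain Python. A minimal sketch under the same assumptions (names and imports come from the snippets above):

    # Dump the field's schema to a plain dict
    d = df.schema["array_field"].jsonValue()

    # Rename "a" to "a_renamed" in the element type's field list
    d["type"]["elementType"]["fields"] = [
        dict(f, name="a_renamed") if f["name"] == "a" else f
        for f in d["type"]["elementType"]["fields"]
    ]

    # Load the schema back and extract the data type
    updated_schema = StructField.fromJson(d).dataType

    df.select(col("array_field").cast(updated_schema)).printSchema()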
    
  • 2020-11-27 16:56

    You can recurse over the data frame's schema to create a new schema with the required changes.

    A schema in PySpark is a StructType, which holds a list of StructFields; each StructField can hold either a primitive type or another StructType.

    This means that we can decide if we want to recurse based on whether the type is a StructType or not.

    Below is an annotated sample implementation of the above idea.

    # Some imports
    from copy import copy
    from pyspark.sql import DataFrame
    from pyspark.sql.types import ArrayType, DataType, StructField, StructType
    
    # We take a dataframe and return a new one with required changes
    def cleanDataFrame(df: DataFrame) -> DataFrame:
        # Returns a new sanitized field name (this function can be anything really)
        def sanitizeFieldName(s: str) -> str:
            return s.replace("-", "_").replace("&", "_").replace("\"", "_")\
                .replace("[", "_").replace("]", "_").replace(".", "_")
    
        # We call this on all fields to create a copy and to perform any changes we might
        # want to do to the field.
        def sanitizeField(field: StructField) -> StructField:
            field = copy(field)
            field.name = sanitizeFieldName(field.name)
            # We recursively call cleanSchema on all types
            field.dataType = cleanSchema(field.dataType)
            return field
    
        def cleanSchema(dataType: DataType) -> DataType:
            dataType = copy(dataType)
            # If the type is a StructType we need to recurse otherwise we can return since
            # we've reached the leaf node
            if isinstance(dataType, StructType):
                # We call our sanitizer for all top level fields
                dataType.fields = [sanitizeField(f) for f in dataType.fields]
            elif isinstance(dataType, ArrayType):
                dataType.elementType = cleanSchema(dataType.elementType)
            return dataType
    
        # With the new schema in hand, create a new DataFrame from the old
        # frame's RDD, using the sanitized schema for the data
        return spark.createDataFrame(df.rdd, cleanSchema(df.schema))
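
    A quick usage sketch (the input column names here are hypothetical, chosen to trigger the sanitizer):

    raw = spark.createDataFrame([("foo", 1)], ["bad-name", "other.field"])
    cleaned = cleanDataFrame(raw)
    cleaned.printSchema()
    # root
    #  |-- bad_name: string (nullable = true)
    #  |-- other_field: long (nullable = true)

    Note that cleanSchema recurses into StructType and ArrayType only; MapType values would need one more branch along the same lines.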
    
  • 2020-11-27 17:03

    I found a much easier way than the one provided by @zero323, along the lines of @MaxPY:

    Pyspark 2.4:

    from pyspark.sql.types import StructField

    # Get the schema from the dataframe df
    schema = df.schema

    # Override `fields` with a list of new StructFields, equal to the previous ones but for the names
    schema.fields = list(map(lambda field:
                             StructField(field.name + "_renamed", field.dataType),
                             schema.fields))

    # Override `names` with the same mechanism
    schema.names = list(map(lambda name: name + "_renamed", schema.names))
    

    Now df.schema will show all the renamed names.
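
    Keep in mind this mutates the schema object in place; to get a DataFrame whose columns actually carry the new names, you can rebuild it from the RDD, following the same pattern as the other answers (a sketch):

    df_renamed = spark.createDataFrame(df.rdd, schema)
    df_renamed.printSchema()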
