Adding a nested column to Spark DataFrame

梦如初夏 · 2020-12-15 10:02

How can I add or replace fields in a struct at any nested level?

This input:

val rdd = sc.parallelize(Seq(
  """{"a": {"xX": 1,"XX": 2},"b": {"z": 0}}"""))
val df = sqlContext.read.json(rdd)

1 Answer
  • 2020-12-15 10:29

    It might not be as elegant or as efficient as it could be, but here is what I came up with:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, struct, when}
    import org.apache.spark.sql.types.{DataType, StructType}

    object DataFrameUtils {
      // Only materialize `c` when the parent struct is non-null, so a
      // missing parent stays null instead of becoming a struct of nulls.
      private def nullableCol(parentCol: Column, c: Column): Column = {
        when(parentCol.isNotNull, c)
      }

      // Convenience overload: guard a column on itself.
      private def nullableCol(c: Column): Column = {
        nullableCol(c, c)
      }
    
      // Wrap `newCol` in one single-field struct per remaining path segment,
      // innermost first: (Seq("y", "z"), v) becomes struct(struct(v as "z") as "y").
      private def createNestedStructs(splitted: Seq[String], newCol: Column): Column = {
        splitted
          .foldRight(newCol) {
            case (colName, nestedStruct) => nullableCol(struct(nestedStruct as colName))
          }
      }
    
      // Rebuild the struct column at this level: copy every existing field,
      // recursing into the one that matches the head of the remaining path.
      private def recursiveAddNestedColumn(splitted: Seq[String], col: Column, colType: DataType, nullable: Boolean, newCol: Column): Column = {
        colType match {
          case colType: StructType if splitted.nonEmpty => {
            var modifiedFields: Seq[(String, Column)] = colType.fields
              .map(f => {
                var curCol = col.getField(f.name)
                if (f.name == splitted.head) {
                  curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
                }
                (f.name, curCol as f.name)
              })
    
            // The path segment does not exist at this level yet: append it
            // as a fresh chain of nested structs.
            if (!modifiedFields.exists(_._1 == splitted.head)) {
              modifiedFields :+= (splitted.head, nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head)
            }
    
            var modifiedStruct: Column = struct(modifiedFields.map(_._2): _*)
            if (nullable) {
              modifiedStruct = nullableCol(col, modifiedStruct)
            }
            modifiedStruct
          }
          // Leaf reached (or a non-struct in the way): build the rest of
          // the path around the new column.
          case _ => createNestedStructs(splitted, newCol)
        }
      }
    
      private def addNestedColumn(df: DataFrame, newColName: String, newCol: Column): DataFrame = {
        if (newColName.contains('.')) {
          val splitted = newColName.split('.')
    
          val modifiedOrAdded: (String, Column) = df.schema.fields
            .find(_.name == splitted.head)
            .map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
            .getOrElse {
              (splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head)
            }
    
          df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)
    
        } else {
          // Top level addition, use spark method as-is
          df.withColumn(newColName, newCol)
        }
      }
    
      implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
        /**
          * Add nested field to DataFrame
          *
          * @param newColName Dot-separated nested field name
          * @param newCol New column value
          */
        def withNestedColumn(newColName: String, newCol: Column): DataFrame = {
          DataFrameUtils.addNestedColumn(df, newColName, newCol)
        }
      }
    }
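
    To see the shape of what gets built when the path is entirely new, here is a rough sketch (hypothetical column names, ignoring the isNotNull guards): createNestedStructs folds the remaining path right-to-left, wrapping the new value in one single-field struct per level.

    // df.withNestedColumn("x.y.z", lit(1)) on a DataFrame with no column "x"
    // reduces to roughly:
    df.withColumn("x", struct(struct(lit(1) as "z") as "y"))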
    

    Feel free to improve on it. Here is a usage example:

    import DataFrameUtils._
    import spark.implicits._

    val data = spark.sparkContext.parallelize(List("""{ "a1": 1, "a3": { "b1": 3, "b2": { "c1": 5, "c2": 6 } } }"""))
    val df: DataFrame = spark.read.json(data)

    val df2 = df.withNestedColumn("a3.b2.c3.d1", $"a3.b2")
    

    It should produce:

    assertResult("struct<a1:bigint,a3:struct<b1:bigint,b2:struct<c1:bigint,c2:bigint,c3:struct<d1:struct<c1:bigint,c2:bigint>>>>>")(df2.schema.simpleString)
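
    Since recursiveAddNestedColumn replaces a field whenever a path segment already exists, the same implicit can also overwrite an existing leaf. A small sketch reusing df from above (the lit value is arbitrary):

    import org.apache.spark.sql.functions.lit

    // Overwrites the existing bigint leaf a3.b2.c1; the surrounding
    // structs are rebuilt around the new value.
    val df3 = df.withNestedColumn("a3.b2.c1", lit(99))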
    