DataFrame-ified zipWithIndex

悲哀的现实 2020-11-27 04:23

I am trying to solve the age-old problem of adding a sequence number to a data set. I am working with DataFrames, and there appears to be no DataFrame equivalent to RDD.zipWithIndex.
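
For context, the usual workaround round-trips through the RDD API. A minimal sketch (the helper name rddZipWithIndex is hypothetical; a Spark 1.6-era SQLContext is assumed):

    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.types.{LongType, StructField, StructType}
    
    // Hypothetical helper: append a 0-based Long index column via RDD.zipWithIndex.
    def rddZipWithIndex(df: DataFrame, indexName: String = "index"): DataFrame = {
      val rows = df.rdd.zipWithIndex.map { case (row, idx) =>
        Row.fromSeq(row.toSeq :+ idx) // append the index as the last field
      }
      val schema = StructType(df.schema.fields :+ StructField(indexName, LongType, nullable = false))
      df.sqlContext.createDataFrame(rows, schema)
    }

This works, but it forces a full pass through RDD[Row] and a manual schema rebuild, which is what a DataFrame-native solution would avoid.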

8 Answers
  •  挽巷 (OP)
     2020-11-27 04:42

    @Evgeny, your solution is interesting. Note that there is a bug when you have empty partitions (the array is missing those partition indexes; at least this is happening to me with Spark 1.6), so I converted the array into a Map(partitionId -> offset).

    Additionally, I took the source of monotonically_increasing_id and adapted it so that "inc_id" restarts in each partition (it starts at 1 rather than 0, so that max("inc_id") equals the partition's row count).

    Here is an updated version:

    import org.apache.spark.sql.catalyst.expressions.LeafExpression
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.LongType
    import org.apache.spark.sql.catalyst.expressions.Nondeterministic
    import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
    import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
    import org.apache.spark.sql.types.DataType
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.expressions.Window
    
    case class PartitionMonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {
    
      /**
       * From org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
       *
       * Record ID within each partition. Because it is transient, count's value is reset
       * every time this expression is serialized, deserialized, and re-initialized.
       */
      @transient private[this] var count: Long = _
    
      override protected def initInternal(): Unit = {
        count = 1L // notice this starts at 1, not 0 as in org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
      }
    
      override def nullable: Boolean = false
    
      override def dataType: DataType = LongType
    
      override protected def evalInternal(input: InternalRow): Long = {
        val currentCount = count
        count += 1
        currentCount
      }
    
      override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
        val countTerm = ctx.freshName("count")
        ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 1L;")
        ev.isNull = "false"
        s"""
          final ${ctx.javaType(dataType)} ${ev.value} = $countTerm;
          $countTerm++;
        """
      }
    }
    
    object DataframeUtils {
      def zipWithIndex(df: DataFrame, offset: Long = 0, indexName: String = "index"): DataFrame = {
        // from https://stackoverflow.com/questions/30304810/dataframe-ified-zipwithindex
        val dfWithPartitionId = df
          .withColumn("partition_id", spark_partition_id())
          .withColumn("inc_id", new Column(PartitionMonotonicallyIncreasingID()))
    
        // collect each partition size, create the offset pages
        val partitionOffsets: Map[Int, Long] = dfWithPartitionId
          .groupBy("partition_id")
          .agg(max("inc_id") as "cnt") // since inc_id starts at 1, max(inc_id) equals count(inc_id) in each partition
          .select(col("partition_id"), sum("cnt").over(Window.orderBy("partition_id")) - col("cnt") + lit(offset) as "cnt")
          .collect()
          .map(r => (r.getInt(0) -> r.getLong(1)))
          .toMap
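        // Example (hypothetical): partition sizes [3, 2, 4] with offset = 0 give
        // partitionOffsets = Map(0 -> 0, 1 -> 3, 2 -> 5): each partition's range of
        // global indexes begins where the previous partition's range ended. Empty
        // partitions simply have no entry, which is safe because the lookup below
        // only runs for rows that live in a non-empty partition.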
    
        def partition_offset(partitionId: Int): Long = partitionOffsets(partitionId)
        val partition_offset_udf = udf(partition_offset _)
        // and re-number the index
        dfWithPartitionId
          .withColumn("partition_offset", partition_offset_udf(col("partition_id")))
          .withColumn(indexName, col("partition_offset") + col("inc_id") - 1) // inc_id starts at 1; subtract 1 so the index starts at `offset`
          .drop("partition_id")
          .drop("partition_offset")
          .drop("inc_id")
      }
    }
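
    For completeness, a hypothetical usage sketch (the input DataFrame and a SQLContext named sqlContext are assumed; "index" is the default column name defined above):

    val df = sqlContext.range(0, 10).toDF("value").repartition(4)
    // Adds a gap-free, global "index" column starting at the given offset (0 by default).
    val indexed = DataframeUtils.zipWithIndex(df)
    val numbered = DataframeUtils.zipWithIndex(df, offset = 1, indexName = "row_num")
    indexed.show()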
    
