Generic iterator over dataframe (Spark/Scala)


Well, a sufficient solution is below:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Seed the running product with the first row's "y", then let scanLeft
// multiply the accumulated value (kept in the appended last column)
// by each subsequent row's "y".
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    val r = Row.fromSeq(head.toSeq :+ head.getInt(head.fieldIndex("y")))
    iter.scanLeft(r)((r1, r2) =>
      Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y"))))
  } else iter
}

// The output schema is the input schema plus the running-product column "s".
val encoder =
  RowEncoder(StructType(df.schema.fields :+ StructField("s", IntegerType, false)))

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

Update

Functions like getInt can be avoided in favor of the more generic getAs.
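
For example, both of the following read the integer column y from a row (row is just an illustrative name here):

row.getInt(row.fieldIndex("y"))   // positional: resolve the index, then read an Int
row.getAs[Int]("y")               // by name, with the type as a parameter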

Also, in order to access the fields of r1 by name, we can generate a GenericRowWithSchema, which is a subclass of Row.

An implicit parameter has been added to f_row so that the function can use the current schema of the data frame while still being usable as an argument to mapPartitions.

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Output schema: the input schema plus the running-product column "result".
implicit val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
implicit val encoder = RowEncoder(schema)

def mul(x1: Int, x2: Int) = x1 * x2

def f_row(iter: Iterator[Row])(implicit currentSchema: StructType): Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    // Attach the schema to the seed row so that later steps can read the
    // accumulated value by name ("result") rather than by position.
    val r =
      new GenericRowWithSchema((head.toSeq :+ head.getAs("y")).toArray, currentSchema)

    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema(
        (r2.toSeq :+ mul(r1.getAs("result"), r2.getAs("y"))).toArray, currentSchema))
  } else iter
}

// The implicit encoder is resolved automatically, so mapPartitions no longer
// needs it passed explicitly.
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show

Finally, the logic can be implemented in a tail-recursive manner.

import scala.annotation.tailrec

def f_row(iter: Iterator[Row]) = {
  // Walk the partition carrying the running product in `tmp` and
  // appending one result row per input row (mul is defined above).
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      val s = mul(tmp, r.getAs("y")) // running product up to and including r
      f_row_(iter, s, result ++ Iterator(Row.fromSeq(r.toSeq :+ s)))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show