How to convert spark SchemaRDD into RDD of my case class?

Submitted by 倾然丶 夕夏残阳落幕 on 2019-11-30 03:56:05

The best solution I've come up with, which requires the least amount of copy-and-paste for each new class, is as follows (I'd still like to see other solutions, though):

First, define your case class and a (partially) reusable factory method:

import org.apache.spark.sql.catalyst.expressions

case class MyClass(fooBar: Long, fred: Long)

// Here you want to auto-generate these functions using macros or something similar
object Factories extends java.io.Serializable {
  def longLong[T](fac: (Long, Long) => T)(row: expressions.Row): T = 
    fac(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long])
}
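To see how the curried factory behaves without a cluster, here is a minimal sketch that assumes a plain Seq[Any] can stand in for the Row (a simplification for local testing; FakeRow, toMyClass and rows are hypothetical names). Fixing the constructor in the first parameter list yields a row-to-object function that can be handed to map.

```scala
// Assumption: a Seq[Any] stands in for Spark's Row so the idea runs locally.
type FakeRow = Seq[Any]

case class MyClass(fooBar: Long, fred: Long)

object Factories extends java.io.Serializable {
  // Curried: supply the constructor first, get back a FakeRow => T function.
  def longLong[T](fac: (Long, Long) => T)(row: FakeRow): T =
    fac(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long])
}

// Partially apply with the companion's apply method.
val toMyClass: FakeRow => MyClass = Factories.longLong[MyClass](MyClass.apply)

val rows: Seq[FakeRow] = Seq(Seq(1L, 2L), Seq(3L, 4L))
val parsed: Seq[MyClass] = rows.map(toMyClass)
```

The same partially applied function is what gets passed as `fac` in the real example further down.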

Some boilerplate, which will already be available:

import scala.reflect.runtime.universe._
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD

The magic

import scala.reflect.ClassTag
import org.apache.spark.sql.SchemaRDD

def camelToUnderscores(name: String) = 
  "[A-Z]".r.replaceAllIn(name, "_" + _.group(0).toLowerCase())

def getCaseMethods[T: TypeTag]: List[String] = typeOf[T].members.sorted.collect {
  case m: MethodSymbol if m.isCaseAccessor => m
}.toList.map(_.toString)

def caseClassToSQLCols[T: TypeTag]: List[String] = 
  getCaseMethods[T].map(_.split(" ")(1)).map(camelToUnderscores)

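def schemaRDDToRDD[T: TypeTag: ClassTag](schemaRDD: SchemaRDD, fac: expressions.Row => T) = {
  val tmpName = "tmpTableName" // Maybe should use a random string
  // registerAsTable was deprecated in Spark 1.1 in favour of registerTempTable
  schemaRDD.registerAsTable(tmpName)
  // Select the columns in case-class field order, then build a T from each Row
  sqlContext.sql("SELECT " + caseClassToSQLCols[T].mkString(", ") + " FROM " + tmpName)
    .map(fac)
}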
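The column-name and query-building logic can be checked without Spark. This sketch swaps in Product.productElementNames (Scala 2.13+), which needs a sample instance rather than a TypeTag, in place of the reflection in getCaseMethods; sqlCols, selectFor and randomTableName (which follows the "maybe should use a random string" comment) are hypothetical helpers, not part of the answer above.

```scala
import java.util.UUID

case class MyClass(fooBar: Long, fred: Long)

// Same regex trick as camelToUnderscores above: "fooBar" -> "foo_bar".
def camelToUnderscores(name: String): String =
  "[A-Z]".r.replaceAllIn(name, m => "_" + m.group(0).toLowerCase)

// Assumption: Scala 2.13+ productElementNames, which lists case-class field
// names in declaration order (the order the factory reads them in).
def sqlCols(sample: Product): List[String] =
  sample.productElementNames.map(camelToUnderscores).toList

// Hypothetical: a unique, SQL-safe temporary table name.
def randomTableName(): String =
  "tmp_" + UUID.randomUUID().toString.replace("-", "")

// Builds the same SELECT string that schemaRDDToRDD sends to sqlContext.sql.
def selectFor(sample: Product, table: String): String =
  "SELECT " + sqlCols(sample).mkString(", ") + " FROM " + table
```

For MyClass this produces "SELECT foo_bar, fred FROM <table>", so the Parquet columns must use the underscored names for the query to resolve.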

Example use

import org.apache.spark.rdd.RDD

val parquetFile = sqlContext.parquetFile(path) // path: location of your Parquet data

val normalRDD: RDD[MyClass] =
  schemaRDDToRDD[MyClass](parquetFile, Factories.longLong[MyClass](MyClass.apply))

See also:

http://apache-spark-user-list.1001560.n3.nabble.com/Spark-SQL-Convert-SchemaRDD-back-to-RDD-td9071.html

However, following the JIRA link from that thread, I couldn't find any example or documentation.

An easy way is to provide your own converter (Row) => CaseClass. This is a bit more manual, but if you know what you are reading it should be quite straightforward.

Here is an example:

import org.apache.spark.sql.{Row, SchemaRDD}
import org.apache.spark.rdd.RDD

case class User(data: String, name: String, id: Long)

// Returns None for rows whose arity or field types don't match User
def sparkSqlToUser(r: Row): Option[User] = r match {
  case Row(data: String, name: String, id: Long) => Some(User(data, name, id))
  case _ => None
}

val parquetData: SchemaRDD = sqlContext.parquetFile("hdfs://localhost/user/data.parquet")

// flatMap drops the Nones, keeping only the rows that converted successfully
val caseClassRdd: RDD[User] = parquetData.flatMap(sparkSqlToUser)
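The converter can be tried without a cluster. This sketch assumes a plain Seq[Any] can stand in for Row (in early Spark 1.x releases Row extends Seq[Any]) and uses the standard Seq(...) extractor in place of Row(...); seqToUser and the sample records are hypothetical. It is written for Scala 2.13+, where Option is directly usable with flatMap.

```scala
case class User(data: String, name: String, id: Long)

// Stand-in for sparkSqlToUser: match on arity and field types at once.
def seqToUser(r: Seq[Any]): Option[User] = r match {
  case Seq(data: String, name: String, id: Long) => Some(User(data, name, id))
  case _ => None // wrong arity or wrong types: drop the record
}

val records: Seq[Seq[Any]] = Seq(
  Seq("a", "alice", 1L),
  Seq("b", "bob"),        // missing id -> None
  Seq("c", "carol", "3")  // id is a String, not a Long -> None
)

// flatMap keeps the Some values and silently discards the Nones.
val users: Seq[User] = records.flatMap(seqToUser)
```

Silently dropping malformed rows is convenient, but it also hides schema mismatches; counting the Nones is one way to notice them.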

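There is a simple way to convert a SchemaRDD to an RDD using PySpark in Spark 1.2.1. Note that collect() brings the whole result to the driver, so this only works for data that fits in driver memory.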

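from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext()  ## create SparkContext
sqlContext = SQLContext(sc)
srdd = sqlContext.sql(sql)  ## sql: your query string
c = srdd.collect()  ## collect the SchemaRDD into a local list of Rows
rdd = sc.parallelize(c)  ## distribute the list again as a plain RDD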

There must be a similar approach in Scala.

A very crufty attempt. I'm very unconvinced this will have decent performance. Surely there must be a macro-based alternative...

import scala.reflect.runtime.universe.typeOf
import scala.reflect.runtime.universe.MethodSymbol
import scala.reflect.runtime.universe.NullaryMethodType
import scala.reflect.runtime.universe.TypeRef
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.NoType
import scala.reflect.runtime.universe.termNames
import scala.reflect.runtime.universe.runtimeMirror

// X is your case class; rowToCaseClass returns Any, so cast the result back to X
schemaRdd.map(row => RowToCaseClass.rowToCaseClass(row.toSeq, typeOf[X], 0).asInstanceOf[X])

object RowToCaseClass {
  // http://dcsobral.blogspot.com/2012/08/json-serialization-with-reflection-in.html
  def rowToCaseClass(record: Seq[_], t: Type, depth: Int): Any = {
    val fields = t.decls.sorted.collect {
      case m: MethodSymbol if m.isCaseAccessor => m
    }
    val values = fields.zipWithIndex.map {
      case (field, i) =>
        field.typeSignature match {
          case NullaryMethodType(sig) if sig =:= typeOf[String] => record(i).asInstanceOf[String]
          case NullaryMethodType(sig) if sig =:= typeOf[Int] => record(i).asInstanceOf[Int]
          case NullaryMethodType(sig) =>
            if (sig.baseType(typeOf[Seq[_]].typeSymbol) != NoType) {
              sig match {
                case TypeRef(_, _, args) =>
                  record(i).asInstanceOf[Seq[Seq[_]]].map {
                    r => rowToCaseClass(r, args(0), depth + 1)
                  }.toSeq
              }
            } else {
              sig match {
                case TypeRef(_, u, _) =>
                  rowToCaseClass(record(i).asInstanceOf[Seq[_]], sig, depth + 1)
              }
            }
        }
    }.asInstanceOf[Seq[Object]]
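    // Use this object's class loader rather than scala-reflect's (t.getClass), so user classes are visible
    val mirror = runtimeMirror(getClass.getClassLoader)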
    val ctor = t.member(termNames.CONSTRUCTOR).asMethod
    val klass = t.typeSymbol.asClass
    val method = mirror.reflectClass(klass).reflectConstructor(ctor)
    method.apply(values: _*)
  }
}
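One alternative to runtime reflection, in the spirit of the macro-based approach wished for above, is a typeclass with one decoder instance per case class. This is a swapped-in technique, not the answer's code: RowDecoder, field, Address and Person are hypothetical, the instances are written by hand here (that boilerplate is the part a macro could generate), and a nested Seq[Any] stands in for the output of row.toSeq. Like rowToCaseClass, it handles String and Int fields, nested records, and Seqs of nested records.

```scala
// Hypothetical typeclass (not a Spark API): decodes one field value.
trait RowDecoder[T] {
  def decode(value: Any): T
}

object RowDecoder {
  def apply[T](implicit d: RowDecoder[T]): RowDecoder[T] = d

  implicit val stringDecoder: RowDecoder[String] = v => v.asInstanceOf[String]
  implicit val intDecoder: RowDecoder[Int] = v => v.asInstanceOf[Int]

  // A Seq field holds nested records: decode each element with its decoder.
  implicit def seqDecoder[A](implicit d: RowDecoder[A]): RowDecoder[Seq[A]] =
    v => v.asInstanceOf[Seq[Any]].map(d.decode)
}

// Read field i of a record using whatever decoder is in scope for T.
def field[T](r: Seq[Any], i: Int)(implicit d: RowDecoder[T]): T = d.decode(r(i))

case class Address(city: String, number: Int)
object Address {
  // Hand-written instance; field positions follow declaration order.
  implicit val decoder: RowDecoder[Address] = v => {
    val r = v.asInstanceOf[Seq[Any]]
    Address(field[String](r, 0), field[Int](r, 1))
  }
}

case class Person(name: String, age: Int, addresses: Seq[Address])
object Person {
  implicit val decoder: RowDecoder[Person] = v => {
    val r = v.asInstanceOf[Seq[Any]]
    Person(field[String](r, 0), field[Int](r, 1), field[Seq[Address]](r, 2))
  }
}

// Stand-in for row.toSeq: nested records become nested Seqs.
val raw: Seq[Any] = Seq("ann", 30, Seq(Seq("Oslo", 1), Seq("Bergen", 2)))
val person: Person = RowDecoder[Person].decode(raw)
```

Because each instance is resolved at compile time, a missing decoder for a field type is a compile error rather than a runtime match failure, and no per-row reflection is needed.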