How to store custom objects in Dataset?

Asked by 别那么骄傲 on 2020-11-22 01:53

According to Introducing Spark Datasets:

As we look forward to Spark 2.0, we plan some exciting improvements to Datasets, specifically: ... Custom encoders - while we currently autogenerate encoders for a wide variety of types, we'd like to open up an API for custom objects.
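
For context, the limitation behind the question: arbitrary classes have no built-in Encoder, so they cannot be stored in a Dataset out of the box. A minimal sketch of the failing case (the class name is illustrative, not from the original post), assuming spark.implicits._ is in scope:

    // An ordinary (non-case) class: Spark provides no implicit Encoder for it.
    class MyObj(val i: Int, val s: String)

    // Fails to compile with an "Unable to find encoder for type MyObj" error:
    // val ds = Seq(new MyObj(1, "a")).toDS()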

9 Answers
  •  一整个雨季
    2020-11-22 01:56

    You can use UDTRegistration, and then case classes, tuples, etc. all work correctly with your User Defined Type!

    Say you want to use a custom Enum:

    trait CustomEnum { def value: String }
    case object Foo extends CustomEnum { val value = "F" }
    case object Bar extends CustomEnum { val value = "B" }
    object CustomEnum {
      // Note: .get throws if the string does not match any of the known values.
      def fromString(str: String): CustomEnum = Seq(Foo, Bar).find(_.value == str).get
    }
    

    Register it like this:

    // First define a UDT class for it:
    import org.apache.spark.sql.types.{DataType, StringType, UserDefinedType}
    import org.apache.spark.unsafe.types.UTF8String

    class CustomEnumUDT extends UserDefinedType[CustomEnum] {
      override def sqlType: DataType = StringType
      // The column is stored as a string: serialize produces a UTF8String,
      // and deserialize receives one back as `datum`.
      override def serialize(obj: CustomEnum): Any = UTF8String.fromString(obj.value)
      override def deserialize(datum: Any): CustomEnum = CustomEnum.fromString(datum.toString)
      override def userClass: Class[CustomEnum] = classOf[CustomEnum]
    }
    
    // Then Register the UDT Class!
    // NOTE: you have to put this file into the org.apache.spark package!
    UDTRegistration.register(classOf[CustomEnum].getName, classOf[CustomEnumUDT].getName)
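
    The usage snippets below also assume a live SparkSession with its implicits imported; a minimal setup sketch (the session construction is not part of the original answer):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("custom-udt-demo")
      .getOrCreate()
    import spark.implicits._  // brings Seq(...).toDS() into scope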
    

    Then USE IT!

    case class UsingCustomEnum(id:Int, en:CustomEnum)
    
    val seq = Seq(
      UsingCustomEnum(1, Foo),
      UsingCustomEnum(2, Bar),
      UsingCustomEnum(3, Foo)
    ).toDS()
    seq.filter(_.en == Foo).show()
    seq.collect().foreach(println)
    

    Say you want to use a Polymorphic Record:

    trait CustomPoly
    case class FooPoly(id:Int) extends CustomPoly
    case class BarPoly(value:String, secondValue:Long) extends CustomPoly
    

    ... and then use it like this:

    case class UsingPoly(id:Int, poly:CustomPoly)
    
    val polySeq = Seq(
      UsingPoly(1, FooPoly(1)),
      UsingPoly(2, BarPoly("Blah", 123)),
      UsingPoly(3, FooPoly(1))
    ).toDS()

    polySeq.filter(_.poly match {
      case FooPoly(value) => value == 1
      case _ => false
    }).show()
    

    You can write a custom UDT that encodes everything to bytes (I'm using Java serialization here, but it's probably better to instrument Spark's Kryo context; a Kryo-based sketch follows the class below).

    First define the UDT class:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import org.apache.spark.sql.types.{BinaryType, DataType, UserDefinedType}

    class CustomPolyUDT extends UserDefinedType[CustomPoly] {
      override def sqlType: DataType = BinaryType

      override def serialize(obj: CustomPoly): Any = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(obj)
        oos.close()  // flush the object stream before reading the buffer
        bos.toByteArray
      }

      override def deserialize(datum: Any): CustomPoly = {
        val bis = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
        val ois = new ObjectInputStream(bis)
        val obj = ois.readObject()
        ois.close()
        obj.asInstanceOf[CustomPoly]
      }

      override def userClass: Class[CustomPoly] = classOf[CustomPoly]
    }
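
    If you do want the Kryo route mentioned above, the serialize/deserialize pair could look roughly like this. This is a sketch using the com.esotericsoftware.kryo API directly rather than Spark's own Kryo configuration; the class name CustomPolyKryoUDT is mine, and depending on the Kryo version you may need to register your classes explicitly. It would be registered with UDTRegistration exactly like CustomPolyUDT below.

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input, Output}
    import java.io.ByteArrayOutputStream
    import org.apache.spark.sql.types.{BinaryType, DataType, UserDefinedType}

    // Same caveat as above: this file must live in the org.apache.spark package.
    class CustomPolyKryoUDT extends UserDefinedType[CustomPoly] {
      // Kryo instances are not thread-safe; a real implementation would pool them
      // or keep one per thread.
      @transient private lazy val kryo = {
        val k = new Kryo()
        // Depending on the Kryo version bundled with Spark, explicit class
        // registration may be required; disabling the requirement keeps the sketch simple.
        k.setRegistrationRequired(false)
        k
      }

      override def sqlType: DataType = BinaryType

      override def serialize(obj: CustomPoly): Any = {
        val bos = new ByteArrayOutputStream()
        val output = new Output(bos)
        kryo.writeClassAndObject(output, obj)
        output.close()
        bos.toByteArray
      }

      override def deserialize(datum: Any): CustomPoly = {
        val input = new Input(datum.asInstanceOf[Array[Byte]])
        val obj = kryo.readClassAndObject(input)
        input.close()
        obj.asInstanceOf[CustomPoly]
      }

      override def userClass: Class[CustomPoly] = classOf[CustomPoly]
    }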
    

    Then register it:

    // NOTE: The file you do this in has to be inside of the org.apache.spark package!
    UDTRegistration.register(classOf[CustomPoly].getName, classOf[CustomPolyUDT].getName)
    

    Then you can use it!

    // As shown above:
    case class UsingPoly(id:Int, poly:CustomPoly)
    
    val polySeq = Seq(
      UsingPoly(1, FooPoly(1)),
      UsingPoly(2, BarPoly("Blah", 123)),
      UsingPoly(3, FooPoly(1))
    ).toDS()

    polySeq.filter(_.poly match {
      case FooPoly(value) => value == 1
      case _ => false
    }).show()
    
