Background: sometimes we need to define an external data source and then process it with Spark SQL. Doing so has two benefits:
(1) Once the external data source is defined, it is very clean to use: the software architecture stays clear and the data can be queried directly through SQL.
(2) It makes it easy to split the system into layers and modules, building upward layer by layer while hiding implementation details.
This is where defining a custom external data source comes in, and it is quite simple to do in Spark: easy once you know how.
First, pick a package name and put all the classes under it, for example com.example.hou.
Then define two things: a DefaultSource, and a subclass of BaseRelation with TableScan.
The DefaultSource code is straightforward, so let's go straight to it:
package com.example.hou
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
class DefaultSource extends CreatableRelationProvider with SchemaRelationProvider {

  // SchemaRelationProvider: the read path, invoked when the user supplies a schema.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(x) => new TextDataSourceRelation(sqlContext, x, schema)
      case _       => throw new IllegalArgumentException("path is required...")
    }
  }

  // CreatableRelationProvider: the write path (df.write...). This example does not
  // persist anything; it simply returns the same relation.
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }
}
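By the way, with only SchemaRelationProvider mixed in, .load() must always be given a schema. If you also wanted schema-less loads, you could mix in RelationProvider as well. Below is a minimal sketch of such a variant (my own assumption, not part of the original code); TextDataSourceRelation already falls back to its built-in schema when null is passed:

package com.example.hou

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Hypothetical variant: supports both schema-less and schema-supplied loads.
class SchemaOptionalSource extends RelationProvider with SchemaRelationProvider {

  // Called when no schema is supplied by the user.
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    createRelation(sqlContext, parameters, null)

  // Called when the user supplies a schema via .schema(...).
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation =
    parameters.get("path") match {
      case Some(p) => new TextDataSourceRelation(sqlContext, p, schema)
      case None    => throw new IllegalArgumentException("path is required...")
    }
}

To actually be picked up via format("com.example.hou"), the class would still have to be named DefaultSource (Spark appends .DefaultSource to the package name when resolving the format), so in practice you would just fold the extra createRelation into the DefaultSource above.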
The source of TextDataSourceRelation:
package com.example.hou
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.TableScan
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
class TextDataSourceRelation(override val sqlContext: SQLContext, path: String, userSchema: StructType)
  extends BaseRelation with TableScan with Logging {

  // If a schema was passed in, use it; otherwise fall back to the built-in schema.
  override def schema: StructType = {
    if (userSchema != null) {
      userSchema
    } else {
      StructType(
        StructField("id", LongType, false) ::
        StructField("name", StringType, false) ::
        StructField("gender", StringType, false) ::
        StructField("salary", LongType, false) ::
        StructField("comm", LongType, false) :: Nil
      )
    }
  }

  // Read the data and convert it into an RDD[Row].
  override def buildScan(): RDD[Row] = {
    logWarning("this is ruozedata buildScan....")

    // Read the file into an RDD. wholeTextFiles returns (fileName, content) pairs,
    // so map(_._2) keeps only the content and drops the file name.
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)

    // Grab the schema fields.
    val schemaField = schema.fields

    // Parse the RDD content and line it up with the schema fields.
    val rows = rdd.map(fileContent => {
      // Split the file content into lines.
      val lines = fileContent.split("\n")
      // Drop comment lines, split each line on commas, trim every value, and turn the result into a Seq.
      val data = lines.filter(line => !line.trim().contains("//")).map(_.split(",").map(_.trim)).toSeq
      // zipWithIndex gives each value its column position, so the matching schema field can be looked up.
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) => {
          val columnName = schemaField(index).name
          // castTo takes two arguments: the raw value (gender is decoded from 0/1 to a label first,
          // every other column is passed through as-is) and the target data type from the schema.
          Utils.castTo(if (columnName.equalsIgnoreCase("gender")) {
            if (value == "0") {
              "man"
            } else if (value == "1") {
              "woman"
            } else {
              "unknown"
            }
          } else {
            value
          }, schemaField(index).dataType)
        }
      })
      result.map(x => Row.fromSeq(x))
    })

    rows.flatMap(x => x)
  }
}
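For reference, buildScan assumes a plain comma-separated text file: lines containing // are skipped, and gender is encoded as 0/1. A made-up data.txt matching the built-in schema could look like this (the values are purely illustrative):

// id,name,gender,salary,comm
1,Tom,0,10000,1000
2,Lucy,1,12000,1500
3,Sam,0,9000,800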
The last step is to use it in a main method:
package com.example.hou
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
object TestApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TextApp")
      .master("local[2]")
      .getOrCreate()

    // Define the schema.
    val schema = StructType(
      StructField("id", LongType, false) ::
      StructField("name", StringType, false) ::
      StructField("gender", StringType, false) ::
      StructField("salary", LongType, false) ::
      StructField("comm", LongType, false) :: Nil)

    // Writing the package name (...example.hou) is enough; there is no need to spell out
    // ...example.hou.DefaultSource, because Spark resolves the DefaultSource class inside it.
    val df = spark.read.format("com.example.hou")
      .option("path", "C://code//data.txt").schema(schema).load()

    df.show()

    df.createOrReplaceTempView("test")
    spark.sql("select name,salary from test").show()

    println("Application Ended...")
    spark.stop()
  }
}
Data type conversion:
package com.example.hou
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.StringType
object Utils {
  // Cast the raw string value to the DataType declared in the schema.
  def castTo(value: String, dataType: DataType) = {
    dataType match {
      case _: LongType   => value.toLong
      case _: StringType => value
    }
  }
}
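Note that castTo only handles the two types used by the built-in schema; any other DataType would hit a MatchError at runtime. If you wanted to cover more types, an extended version might look like this (just a sketch of one possible extension, not part of the original code):

package com.example.hou

import org.apache.spark.sql.types._

// Hypothetical extension of Utils.castTo: a few more primitive types,
// plus a clear error for anything unsupported.
object UtilsExtended {
  def castTo(value: String, dataType: DataType): Any = dataType match {
    case _: LongType    => value.toLong
    case _: IntegerType => value.toInt
    case _: DoubleType  => value.toDouble
    case _: BooleanType => value.toBoolean
    case _: StringType  => value
    case other          => throw new IllegalArgumentException(s"Unsupported data type: $other")
  }
}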