spark scala dataframe merge multiple dataframes


Question:


I have three files coming in,

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ab|  ac|
## |  2| bb|  bc|  bd|
## +---+---+----+----+

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ab|  ad|
## |  2| bb|  bb|  bd|
## +---+---+----+----+

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ac|  ad|
## |  2| bb|  bc|  bd|
## +---+---+----+----+

I need to compare the first two files (which I'm reading in as dataframes), identify only the changes, and then merge them with the third file, so my output should be:

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ac|  ad|
## |  2| bb|  bb|  bd|
## +---+---+----+----+

How do I pick only the changed values and update the other dataframe?
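
For reference, reading the three files in might look something like this (a sketch with placeholder paths, assuming CSV files with headers on Spark 2.x):

val df1 = spark.read.option("header", "true").csv("/path/to/file1.csv") // first snapshot
val df2 = spark.read.option("header", "true").csv("/path/to/file2.csv") // second snapshot
val df3 = spark.read.option("header", "true").csv("/path/to/file3.csv") // third snapshot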


Answer 1:


I can't comment yet, so I will try to solve this here; it may still need to be amended. From what I can tell, you are looking for the last unique change. For val1 the sequences are {ab -> ab -> ac} and {bc -> bb -> bc}, so the end result is {ac, bb}: in the second row the last file has bc, which was already in the first file and is therefore not a new change. If that is the case, then one way to deal with this is to build a Set from the three values and take the last element (Scala's immutable Set drops duplicates and preserves insertion order for up to four elements, so .last is the most recent distinct value). I will use a UDF to get this done.
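
For illustration, this is how that Set trick behaves on the two val1 sequences from the question (plain Scala, no Spark needed):

Set("ab", "ab", "ac").last // "ac" -- the third value is new, so it wins
Set("bc", "bb", "bc").last // "bb" -- "bc" is a duplicate and is dropped, leaving "bb" last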

So from your example:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import sqlContext.implicits._

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ac"),(2,"bb","bc","bd"))).toDF("pk1","pk2","val1","val2")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ad"),(2,"bb","bb","bd"))).toDF("pk1","pk2","val1","val2")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","ac","ad"),(2,"bb","bc","bd"))).toDF("pk1","pk2","val1","val2")

// build a Set from the three values and take the last (most recent distinct) one
def getChange: UserDefinedFunction =
  udf((a: String, b: String, c: String) => Set(a, b, c).last)

df1
  .join(df2, df1("pk1") === df2("pk1") && df1("pk2") === df2("pk2"), "inner")
  .join(df3, df1("pk1") === df3("pk1") && df1("pk2") === df3("pk2"), "inner")
  .select(df1("pk1"), df1("pk2"),
    df1("val1").as("df1Val1"), df2("val1").as("df2Val1"), df3("val1").as("df3Val1"),
    df1("val2").as("df1Val2"), df2("val2").as("df2Val2"), df3("val2").as("df3Val2"))
  .withColumn("val1", getChange($"df1Val1", $"df2Val1", $"df3Val1"))
  .withColumn("val2", getChange($"df1Val2", $"df2Val2", $"df3Val2"))
  .select($"pk1", $"pk2", $"val1", $"val2")
  .orderBy($"pk1")
  .show(false)

This yields:

+---+---+----+----+
|pk1|pk2|val1|val2|
+---+---+----+----+
|1  |aa |ac  |ad  |
|2  |bb |bb  |bd  |
+---+---+----+----+

Obviously, if you use more columns or more dataframes this becomes cumbersome to write out by hand, but it should do the trick for your example.

Edit:
This extends the approach to more value columns. As I said above, it is a bit more cumbersome. It iterates through each value column, building one small dataframe per column until none are left, and then joins the results back together.

require(df1.columns.sameElements(df2.columns) && df1.columns.sameElements(df3.columns),"DF Columns do not match") //this is a check so may not be needed

val cols: Array[String] = df1.columns

def getChange: UserDefinedFunction = udf((a: String, b: String, c: String) => Set(a,b,c).last)

def createFrame(cols: Array[String], df1: DataFrame, df2: DataFrame, df3: DataFrame): scala.collection.mutable.ListBuffer[DataFrame] = {

  val list: scala.collection.mutable.ListBuffer[DataFrame] = new scala.collection.mutable.ListBuffer[DataFrame]()
  val keys = cols.slice(0, 2)                     // get the keys
  val columns = cols.slice(2, cols.length).toSeq  // get the columns to use

  def helper(columns: Seq[String]): scala.collection.mutable.ListBuffer[DataFrame] = {
    if (columns.isEmpty) list
    else {
      // one small dataframe per value column: the keys plus the resolved value
      list += df1
        .join(df2, df1.col(keys(0)) === df2.col(keys(0)) && df1.col(keys(1)) === df2.col(keys(1)), "inner")
        .join(df3, df1.col(keys(0)) === df3.col(keys(0)) && df1.col(keys(1)) === df3.col(keys(1)), "inner")
        .select(df1.col(keys(0)), df1.col(keys(1)),
          getChange(df1.col(columns.head), df2.col(columns.head), df3.col(columns.head)).as(columns.head))

      helper(columns.tail) // tail recursion over the remaining columns
    }
  }

  helper(columns)
}

val list: scala.collection.mutable.ListBuffer[DataFrame] = createFrame(cols, df1, df2, df3)

// join the per-column frames back together on the keys
list.reduce((a, b) =>
    a
      .join(b, a(cols.head) === b(cols.head) && a(cols(1)) === b(cols(1)), "inner")
      .drop(b(cols.head))
      .drop(b(cols(1))))
  .select(cols.head, cols.tail: _*)
  .orderBy(cols.head)
  .show

An example with three value columns, which can then be passed into the code above:

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ac","ad"),(2,"bb","bc","bd","bc"))).toDF("pk1","pk2","val1","val2","val3")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ad","ae"),(2,"bb","bb","bd","bf"))).toDF("pk1","pk2","val1","val2","val3")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","ac","ad","ae"),(2,"bb","bc","bd","bg"))).toDF("pk1","pk2","val1","val2","val3")

Running the code above on these inputs yields:

//output
+---+---+----+----+----+
|pk1|pk2|val1|val2|val3|
+---+---+----+----+----+
|  1| aa|  ac|  ad|  ae|
|  2| bb|  bb|  bd|  bg|
+---+---+----+----+----+

There may well be a more efficient way to do this; this was just off the top of my head.
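
As one possibly more efficient variant, the same "last distinct value" rule can be expressed with built-in column expressions instead of a UDF, which Catalyst can optimize. A sketch, assuming the two-key schema above; lastChange is a hypothetical helper, and =!= requires Spark 2.0+ (use !== on 1.x):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.when

// equivalent to Set(a, b, c).last: take c if it is new, else b if it changed, else a
def lastChange(a: Column, b: Column, c: Column): Column =
  when(c =!= a && c =!= b, c).when(b =!= a, b).otherwise(a)

df1
  .join(df2, Seq("pk1", "pk2"))
  .join(df3, Seq("pk1", "pk2"))
  .select(df1("pk1"), df1("pk2"),
    lastChange(df1("val1"), df2("val1"), df3("val1")).as("val1"),
    lastChange(df1("val2"), df2("val2"), df3("val2")).as("val2"))
  .show(false)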

Edit 2:

To do this with any number of keys, you can do the following; you just need to define the number of key columns up front. This can probably be cleaned up further. I've got it working with 4-5 keys, but you should run some tests of your own; it should work:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.UserDefinedFunction

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ab","ac","ad"),(2,"bb","d","e","bc","bd","bc"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ab","ad","ae"),(2,"bb","d","e","bb","bd","bf"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ac","ad","ae"),(2,"bb","d","e","bc","bd","bg"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")

require(df1.columns.sameElements(df2.columns) && df1.columns.sameElements(df3.columns),"DF Columns do not match")

val cols: Array[String] = df1.columns

def getChange: UserDefinedFunction = udf((a: String, b: String, c: String) => Set(a,b,c).last)

def createFrame(cols: Array[String], df1: DataFrame, df2: DataFrame, df3: DataFrame): scala.collection.mutable.ListBuffer[DataFrame] = {

  val list: scala.collection.mutable.ListBuffer[DataFrame] = new scala.collection.mutable.ListBuffer[DataFrame]()
  val keys = cols.slice(0, 4)                     // get the keys
  val columns = cols.slice(4, cols.length).toSeq  // get the columns to use

  def helper(columns: Seq[String]): scala.collection.mutable.ListBuffer[DataFrame] = {
    if (columns.isEmpty) list
    else {
      list += df1
        .join(df2, Seq(keys: _*), "inner")
        .join(df3, Seq(keys: _*), "inner")
        .withColumn(columns.head + "Out", getChange(df1.col(columns.head), df2.col(columns.head), df3.col(columns.head)))
        .select(col(columns.head + "Out").as(columns.head) +: keys.map(x => df1.col(x)): _*)

      helper(columns.tail)
    }
  }

  helper(columns)
}

val list: scala.collection.mutable.ListBuffer[DataFrame] = createFrame(cols, df1, df2, df3)
list.foreach(a => a.show(false)) // inspect the intermediate frames
val keys = cols.slice(0, 4)

list.reduce((a, b) =>
    a.alias("a").join(b.alias("b"), Seq(keys: _*), "inner")
      .select("a.*", "b." + b.columns.head))
  .orderBy(cols.head)
  .show(false)

This yields:

+---+---+---+---+----+----+----+
|pk1|pk2|pk3|pk4|val1|val2|val3|
+---+---+---+---+----+----+----+
|1  |aa |c  |d  |ac  |ad  |ae  |
|2  |bb |d  |e  |bb  |bd  |bg  |
+---+---+---+---+----+----+----+



Answer 2:


I can also do this by registering the dataframes as temp views and then using a SELECT with a CASE statement, like this:

df1.createTempView("df1")
df2.createTempView("df2")
df3.createTempView("df3")

select case when df1.val1=df2.val1 and df1.val1<>df3.val1 then df3.val1 end

This was much faster for me.
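
For completeness, a fleshed-out version of that query might look like the following. This is a sketch assuming the two-key schema from the question and a Spark 2.x session named spark; the CASE branches mirror the "last distinct value" rule from Answer 1:

spark.sql("""
  SELECT df1.pk1, df1.pk2,
         CASE WHEN df3.val1 <> df1.val1 AND df3.val1 <> df2.val1 THEN df3.val1
              WHEN df2.val1 <> df1.val1 THEN df2.val1
              ELSE df1.val1 END AS val1,
         CASE WHEN df3.val2 <> df1.val2 AND df3.val2 <> df2.val2 THEN df3.val2
              WHEN df2.val2 <> df1.val2 THEN df2.val2
              ELSE df1.val2 END AS val2
  FROM df1
  JOIN df2 ON df1.pk1 = df2.pk1 AND df1.pk2 = df2.pk2
  JOIN df3 ON df1.pk1 = df3.pk1 AND df1.pk2 = df3.pk2
""").show()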



Source: https://stackoverflow.com/questions/44663746/spark-scala-dataframe-merge-multiple-dataframes
