spark scala - Group by Array column [duplicate]

Submitted by 人走茶凉 on 2020-03-03 09:23:27

Question


I am very new to Spark with Scala and would appreciate your help. I have a DataFrame:

val df = Seq(
  ("a", "a1", Array("x1","x2")), 
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")

I am looking for a way to group it by k1 and k3 and collect k2 into an array. However, k3 is an array, and the grouping on k3 has to use "contains" rather than an exact match: for example, Array("x1") and Array("x2","x1") should land in the same group as Array("x1","x2"). In other words, I am looking for a result like this:

k1  k3        k2              count
a   (x1,x2)   (a1,b1,c1,e1)   4
a   (x3)      (d1)            1
c   (x2)      (c3)            1

Can somebody advise how to achieve this?

Thanks in advance!
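To make the requirement concrete, here is a minimal Spark-free sketch in plain Scala. It assumes that "contains" means the two k3 arrays share at least one element, which is how the answer below interprets it. The object name `ExactVsOverlap` and the helper `overlaps` are hypothetical, and Seq stands in for the Spark array column. An exact-match grouping on (k1, k3) produces six groups, because Seq("x1","x2") and Seq("x2","x1") are different keys, while the desired result has only three.

```scala
object ExactVsOverlap {
  // (k1, k2, k3) rows mirroring the DataFrame in the question
  val rows: Seq[(String, String, Seq[String])] = Seq(
    ("a", "a1", Seq("x1", "x2")),
    ("a", "b1", Seq("x1")),
    ("a", "c1", Seq("x2")),
    ("c", "c3", Seq("x2")),
    ("a", "d1", Seq("x3")),
    ("a", "e1", Seq("x2", "x1"))
  )

  // Exact-match grouping on (k1, k3): element order inside k3 matters
  val exactGroups: Map[(String, Seq[String]), Seq[String]] =
    rows.groupBy { case (k1, _, k3) => (k1, k3) }
      .map { case (key, rs) => key -> rs.map(_._2) }

  // Hypothetical helper (assumed meaning of "contains"):
  // true when the two k3 arrays share at least one element
  def overlaps(a: Seq[String], b: Seq[String]): Boolean =
    a.exists(x => b.contains(x))

  def main(args: Array[String]): Unit = {
    println(s"exact-match groups: ${exactGroups.size}") // prints "exact-match groups: 6"
    println(overlaps(Seq("x1", "x2"), Seq("x1")))       // prints true
    println(overlaps(Seq("x1", "x2"), Seq("x3")))       // prints false
  }
}
```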


Answer 1:


I would suggest grouping by the k1 column, collecting a list of structs of k2 and k3, and passing the collected list to a UDF that merges the entries whose k3 arrays share elements, collecting the k2 values and counting the rows of each merged group.

Then you can use explode and select expressions to turn the UDF result into the desired output.
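The merging step can be checked on plain Scala collections, without a Spark cluster. This is a minimal sketch under stated assumptions, not the answer's exact code: `OverlapGrouping`, `mergeOverlapping` and `run` are hypothetical names, Seq stands in for Spark arrays, and `groupBy(_._1)` plus `flatMap` stand in for `groupBy("k1")` and `explode`.

```scala
object OverlapGrouping {
  // (k1, k2, k3) rows mirroring the DataFrame in the question
  val sample: Seq[(String, String, Seq[String])] = Seq(
    ("a", "a1", Seq("x1", "x2")),
    ("a", "b1", Seq("x1")),
    ("a", "c1", Seq("x2")),
    ("c", "c3", Seq("x2")),
    ("a", "d1", Seq("x3")),
    ("a", "e1", Seq("x2", "x1"))
  )

  // Merges the (k2, k3) pairs of a single k1 value. Each group is
  // (union of k3 values, k2 values, row count). A new pair fuses every
  // existing group whose k3 set shares an element with it, so groups stay disjoint.
  def mergeOverlapping(pairs: Seq[(String, Seq[String])]): List[(Set[String], Set[String], Int)] =
    pairs.foldLeft(List.empty[(Set[String], Set[String], Int)]) { case (acc, (k2, k3Seq)) =>
      val k3 = k3Seq.toSet
      // Split existing groups into those overlapping this pair's k3 and the rest
      val (hit, miss) = acc.partition { case (keys, _, _) => keys.exists(k3.contains) }
      // Fuse the current pair with every overlapping group
      val fused = hit.foldLeft((k3, Set(k2), 1)) {
        case ((keys, k2s, n), (gKeys, gK2s, gN)) => (keys ++ gKeys, k2s ++ gK2s, n + gN)
      }
      fused :: miss
    }

  // Plain-collection stand-in for groupBy("k1") + UDF + explode:
  // one output tuple (k1, k3 set, k2 set, count) per merged group
  def run(rows: Seq[(String, String, Seq[String])]): Seq[(String, Set[String], Set[String], Int)] =
    rows.groupBy(_._1).toSeq.flatMap { case (k1, rs) =>
      mergeOverlapping(rs.map { case (_, k2, k3) => (k2, k3) })
        .map { case (keys, k2s, n) => (k1, keys, k2s, n) }
    }

  def main(args: Array[String]): Unit =
    run(sample).sortBy { case (k1, _, _, n) => (k1, -n) }.foreach(println)
}
```

Fusing the overlapping groups, rather than appending the row to each matching group, keeps every row in exactly one group: rows with k3 = Seq("x1"), Seq("x3") and Seq("x1","x3") end up in a single group with count 3.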

The following is the complete working solution:

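import spark.implicits._ // provides .toDF (assumes a SparkSession named spark; automatic in spark-shell)

val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")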

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Receives the (k2, k3) structs collected for one k1 value and merges rows
// whose k3 arrays share at least one element. Each group is (k2 values,
// union of k3 values, row count); a row that overlaps several existing
// groups fuses them into one, so no row is counted twice.
val containsGroupingUdf = udf((rows: Seq[Row]) => {
  val groups = rows.foldLeft(List.empty[(Set[String], Set[String], Int)]) { (acc, row) =>
    val k2 = row.getAs[String]("k2")
    // getSeq avoids hard-coding the concrete collection class Spark uses for arrays
    val k3 = row.getSeq[String](row.fieldIndex("k3")).toSet
    val (overlapping, others) = acc.partition { case (_, keys, _) => keys.exists(k3.contains) }
    val merged = overlapping.foldLeft((Set(k2), k3, 1)) {
      case ((k2s, keys, n), (gK2s, gKeys, gN)) => (k2s ++ gK2s, keys ++ gKeys, n + gN)
    }
    merged :: others
  }
  // Tuples become structs with fields _1 (k2), _2 (k3) and _3 (count)
  groups.map { case (k2s, keys, n) => (k2s.toSeq.sorted, keys.toSeq.sorted, n) }
})

df.groupBy("k1")
  .agg(containsGroupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
  .select(col("k1"), explode(col("aggregated")).as("aggregated"))
  .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)

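which should give output along these lines (the order of the rows, and of the elements inside the arrays, is not guaranteed):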

+---+--------+----------------+-----+
|k1 |k3      |k2              |count|
+---+--------+----------------+-----+
|c  |[x2]    |[c3]            |1    |
|a  |[x1, x2]|[b1, e1, c1, a1]|4    |
|a  |[x3]    |[d1]            |1    |
+---+--------+----------------+-----+ 

I hope the answer is helpful and you can modify it according to your needs.



Source: https://stackoverflow.com/questions/50672206/spark-scala-group-by-array-column
