Question
I am very new to Spark Scala. I would appreciate your help. I have a dataframe:
val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")
I am looking for a way to group it by k1 and k3 and collect k2 in an array. However, k3 is an array and I need to apply contains (rather than exact match) for the grouping. In other words, I am looking for a result something like this
k1  k3        k2              count
a   (x1,x2)   (a1,b1,c1,e1)   4
a   (x3)      (d1)            1
c   (x2)      (c3)            1
Can somebody advise how to achieve this?
Thanks in advance!
Answer 1:
I would suggest grouping by the k1 column, collecting a list of structs of k2 and k3, and passing the collected list to a udf function that merges rows whose k3 arrays overlap (contain a common element), accumulating the k2 values and counting the merged rows. Then you can use explode and select expressions to get the desired output.
The following is the complete working solution:
val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Merges the collected (k2, k3) structs: a row joins every existing group whose
// k3 set overlaps its own k3 array; otherwise it starts a new group.
// Each group is (set of k2 values, set of k3 values, row count).
def containsGroupingUdf = udf((arr: Seq[Row]) => {
  val firstStruct = arr.head
  val tailStructs = arr.tail
  var result = Array((collection.mutable.Set(firstStruct.getAs[String]("k2")),
                      firstStruct.getAs[Seq[String]]("k3").toSet,
                      1))
  for (str <- tailStructs) {
    val k2 = str.getAs[String]("k2")
    val k3 = str.getAs[Seq[String]]("k3")
    var added = false
    for ((res, index) <- result.zipWithIndex) {
      // merge when the two k3 arrays share at least one element
      if (k3.exists(res._2) || res._2.exists(x => k3.contains(x))) {
        result(index) = (res._1 + k2, res._2 ++ k3.toSet, res._3 + 1)
        added = true
      }
    }
    if (!added) {
      result = result ++ Array((collection.mutable.Set(k2), k3.toSet, 1))
    }
  }
  result.map(tuple => (tuple._1.toArray, tuple._2.toArray, tuple._3))
})

df.groupBy("k1")
  .agg(containsGroupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
  .select(col("k1"), explode(col("aggregated")).as("aggregated"))
  .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)
which should give you
+---+--------+----------------+-----+
|k1 |k3 |k2 |count|
+---+--------+----------------+-----+
|c |[x2] |[c3] |1 |
|a |[x1, x2]|[b1, e1, c1, a1]|4 |
|a |[x3] |[d1] |1 |
+---+--------+----------------+-----+
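For reference, here is a minimal, Spark-free sketch of the same overlap-grouping idea the udf applies, written against plain tuples instead of Rows. It is a simplified variant (it merges a row into the first overlapping bucket only), and the names Bucket and groupByOverlap are illustrative, not part of the original answer:

case class Bucket(k2s: Set[String], k3s: Set[String], count: Int)

// Fold (k2, k3) pairs into buckets: a pair joins the first bucket whose
// k3 set shares an element with its own k3, otherwise it starts a new bucket.
def groupByOverlap(rows: Seq[(String, Seq[String])]): Seq[Bucket] =
  rows.foldLeft(Seq.empty[Bucket]) { (buckets, row) =>
    val (k2, k3) = row
    val idx = buckets.indexWhere(b => b.k3s.exists(x => k3.contains(x)))
    if (idx >= 0) {
      val b = buckets(idx)
      buckets.updated(idx, Bucket(b.k2s + k2, b.k3s ++ k3, b.count + 1))
    } else {
      buckets :+ Bucket(Set(k2), k3.toSet, 1)
    }
  }

// groupByOverlap(Seq(("a1", Seq("x1", "x2")), ("b1", Seq("x1")), ("d1", Seq("x3"))))
// => Seq(Bucket(Set(a1, b1), Set(x1, x2), 2), Bucket(Set(d1), Set(x3), 1))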
I hope the answer is helpful and you can modify it according to your needs.
Source: https://stackoverflow.com/questions/50672206/spark-scala-group-by-array-column