How to filter rows for a specific aggregate with Spark SQL?

Submitted by 隐身守侯 on 2019-12-07 02:18:12

Question


Normally, all rows in a group are passed to an aggregate function. I would like to filter rows by a condition so that only some of the rows in a group are passed to the aggregate function. PostgreSQL supports this kind of operation. I would like to do the same with a Spark SQL DataFrame (Spark 2.0.0).

The code could probably look like this:

val df = ... // some data frame
df.groupBy("A").agg(
  max("B").where("B").less(10), // there is no such method as `where` :(
  max("C").where("C").less(5)
)

So for a data frame like this:

| A | B | C |
|  1| 14|  4|
|  1|  9|  3|
|  2|  5|  6|

The result would be:

|A|max(B)|max(C)|
|1|    9|      4|
|2|    5|   null|

Is it possible with Spark SQL?

Note that, in general, any aggregate function other than max could be used, and there could be multiple aggregates over the same column with arbitrary filter conditions.


Answer 1:


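import org.apache.spark.sql.functions.{max, when}
import spark.implicits._ // assumes a SparkSession named `spark`; enables toDF and $"col"

val df = Seq(
    (1,14,4),
    (1,9,3),
    (2,5,6)
  ).toDF("a","b","c")

// when(...) without otherwise(...) yields null for non-matching rows, and max ignores nulls
val aggregatedDF = df.groupBy("a")
  .agg(
    max(when($"b" < 10, $"b")).as("MaxB"),
    max(when($"c" < 5, $"c")).as("MaxC")
  )

aggregatedDF.show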
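
To make the mechanics of the snippet above concrete, here is a minimal plain-Python sketch of the same idea, without Spark. The helper names `when` and `agg_ignoring_nulls` are hypothetical stand-ins written for this illustration, not Spark APIs. The sketch assumes the aggregate skips nulls and returns null when no values remain, which is how Spark's max, min, sum and avg behave (count is an exception, since it returns 0). The rows are the ones from the question.

```python
from collections import defaultdict

def when(condition, value):
    """Return value if the condition holds, else None (a `when` without `otherwise`)."""
    return value if condition else None

def agg_ignoring_nulls(func, values):
    """Apply func to the non-None values; return None if nothing is left."""
    kept = [v for v in values if v is not None]
    return func(kept) if kept else None

rows = [(1, 14, 4), (1, 9, 3), (2, 5, 6)]  # (A, B, C) rows from the question

# Group by A, replacing values that fail the per-column condition with None.
groups = defaultdict(list)
for a, b, c in rows:
    groups[a].append((when(b < 10, b), when(c < 5, c)))

# Aggregate each column independently, skipping the None placeholders.
result = {
    a: (
        agg_ignoring_nulls(max, [b for b, _ in vals]),
        agg_ignoring_nulls(max, [c for _, c in vals]),
    )
    for a, vals in groups.items()
}
print(result)  # {1: (9, 4), 2: (5, None)}
```

Because the condition is applied per column before aggregation, the same pattern works for other null-skipping aggregates and for several differently filtered aggregates over one column.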



Answer 2:


    >>> from pyspark.sql import Row
    >>> df = sc.parallelize([[1,14,1],[1,9,3],[2,5,6]]).map(lambda t: Row(a=int(t[0]),b=int(t[1]),c=int(t[2]))).toDF()
    >>> df.registerTempTable('t')
    >>> res = sqlContext.sql("select a,max(case when b<10 then b else null end) mb,max(case when c<5 then c else null end) mc from t group by a")
    >>> res.show()
    +---+---+----+
    |  a| mb|  mc|
    +---+---+----+
    |  1|  9|   3|
    |  2|  5|null|
    +---+---+----+

You can also use plain SQL with a CASE expression inside each aggregate (I believe you can do the same thing in Postgres).
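
The CASE-inside-an-aggregate pattern is ordinary SQL, so it can be tried without a Spark cluster. The sketch below runs an equivalent query against an in-memory SQLite database from Python's standard library. SQLite is only a stand-in here (an assumption of this example, not part of the answer), and the rows are the ones used in this answer's session. The `else null` branch is omitted because a CASE with no ELSE already yields NULL.

```python
import sqlite3

# In-memory SQLite database standing in for the Spark temp table `t`.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER)")
conn.executemany(
    "INSERT INTO t VALUES (?, ?, ?)",
    [(1, 14, 1), (1, 9, 3), (2, 5, 6)],  # same rows as the answer above
)

# Rows failing the condition become NULL, and MAX skips NULLs.
query = """
    SELECT a,
           MAX(CASE WHEN b < 10 THEN b END) AS mb,
           MAX(CASE WHEN c < 5 THEN c END) AS mc
    FROM t
    GROUP BY a
    ORDER BY a
"""
result_rows = conn.execute(query).fetchall()
print(result_rows)  # [(1, 9, 3), (2, 5, None)]
conn.close()
```

The `ORDER BY a` is added only to make the row order deterministic for this example.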




Answer 3:


// $less compares each aggregate's result with a constant, so the output columns are booleans
df.groupBy("name", "age", "id").agg(functions.max("age").$less(20), functions.max("id").$less("30")).show();

Sample Data:

name    age id
abc     23  1001
cde     24  1002
efg     22  1003
ghi     21  1004
ijk     20  1005
klm     19  1006
mno     18  1007
pqr     18  1008
rst     26  1009
tuv     27  1010
pqr     18  1012
rst     28  1013
tuv     29  1011
abc     24  1015

Output:

+----+---+----+---------------+--------------+
|name|age|  id|(max(age) < 20)|(max(id) < 30)|
+----+---+----+---------------+--------------+
| rst| 26|1009|          false|          true|
| abc| 23|1001|          false|          true|
| ijk| 20|1005|          false|          true|
| tuv| 29|1011|          false|          true|
| efg| 22|1003|          false|          true|
| mno| 18|1007|           true|          true|
| tuv| 27|1010|          false|          true|
| klm| 19|1006|           true|          true|
| cde| 24|1002|          false|          true|
| pqr| 18|1008|           true|          true|
| abc| 24|1015|          false|          true|
| ghi| 21|1004|          false|          true|
| rst| 28|1013|          false|          true|
| pqr| 18|1012|           true|          true|
+----+---+----+---------------+--------------+


Source: https://stackoverflow.com/questions/39713355/how-to-filter-rows-for-a-specific-aggregate-with-spark-sql
