How to filter one Spark DataFrame against another DataFrame

Submitted by Anonymous (unverified) on 2019-12-03 08:46:08

Question:

I'm trying to filter one dataframe against another:

scala> val df1 = sc.parallelize((1 to 100).map(a => (s"user $a", a*0.123, a))).toDF("name", "score", "user_id")
scala> val df2 = sc.parallelize(List(2,3,4,5,6)).toDF("valid_id")

Now I want to filter df1 and get back a dataframe that contains all the rows in df1 where user_id is in df2("valid_id"). In other words, I want all the rows in df1 where the user_id is either 2, 3, 4, 5, or 6.

scala> df1.select("user_id").filter($"user_id" in df2("valid_id"))
warning: there were 1 deprecation warning(s); re-run with -deprecation for details
org.apache.spark.sql.AnalysisException: resolved attribute(s) valid_id#20 missing from user_id#18 in operator !Filter user_id#18 IN (valid_id#20);

On the other hand, when I filter against a plain expression, everything looks great:

scala> df1.select("user_id").filter(($"user_id" % 2) === 0)
res1: org.apache.spark.sql.DataFrame = [user_id: int]

Why am I getting this error? Is there something wrong with my syntax?

Following a comment, I tried a left outer join:

scala> df1.show
+-------+------------------+-------+
|   name|             score|user_id|
+-------+------------------+-------+
| user 1|             0.123|      1|
| user 2|             0.246|      2|
| user 3|             0.369|      3|
| user 4|             0.492|      4|
| user 5|             0.615|      5|
| user 6|             0.738|      6|
| user 7|             0.861|      7|
| user 8|             0.984|      8|
| user 9|             1.107|      9|
|user 10|              1.23|     10|
|user 11|             1.353|     11|
|user 12|             1.476|     12|
|user 13|             1.599|     13|
|user 14|             1.722|     14|
|user 15|             1.845|     15|
|user 16|             1.968|     16|
|user 17|             2.091|     17|
|user 18|             2.214|     18|
|user 19|2.3369999999999997|     19|
|user 20|              2.46|     20|
+-------+------------------+-------+
only showing top 20 rows

scala> df2.show
+--------+
|valid_id|
+--------+
|       2|
|       3|
|       4|
|       5|
|       6|
+--------+

scala> df1.join(df2, df1("user_id") === df2("valid_id"))
res6: org.apache.spark.sql.DataFrame = [name: string, score: double, user_id: int, valid_id: int]

scala> res6.collect
res7: Array[org.apache.spark.sql.Row] = Array()

scala> df1.join(df2, df1("user_id") === df2("valid_id"), "left_outer")
res8: org.apache.spark.sql.DataFrame = [name: string, score: double, user_id: int, valid_id: int]

scala> res8.count
res9: Long = 0

I'm running Spark 1.5.0 with Scala 2.10.5.

Answer 1:

A filter condition can only reference columns of the DataFrame being filtered, which is why passing df2("valid_id") raises the AnalysisException: comparing against another DataFrame's column has to be expressed as a join. You want a (regular) inner join, not an outer join :)

df1.join(df2, df1("user_id") === df2("valid_id")) 
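As a side note, the deprecation warning in the question comes from in, which Spark 1.5 replaced with isin. If df2 is small enough to collect to the driver, a minimal sketch of an alternative that keeps the filter on a single DataFrame (run in the same spark-shell session, so the $ column syntax is already imported) could look like this:

// Pull the valid ids down to the driver; assumes df2 comfortably fits in memory.
val validIds = df2.collect().map(_.getInt(0))            // Array(2, 3, 4, 5, 6)

// isin takes varargs of literal values, so splat the local array into it.
val filtered = df1.filter($"user_id".isin(validIds: _*))

filtered.show()

If you prefer to stay with the join but don't want the extra valid_id column in the result, a "leftsemi" join type in DataFrame.join keeps only df1's columns while still filtering against df2.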

