Question
I'm trying to check which entries in a Spark DataFrame (a column of lists) contain the largest number of values from a given list.
The best approach I've come up with is iterating over the DataFrame with rdd.foreach()
and comparing the given list to every entry using Python's set1.intersection(set2).
My question is: does Spark have any built-in functionality for this, so that iterating with .foreach
could be avoided?
Thanks for any help!
P.S. my dataframe looks like this:
+-------------+---------------------+
| cardnumber|collect_list(article)|
+-------------+---------------------+
|2310000000855| [12480, 49627, 80...|
|2310000008455| [35531, 22564, 15...|
|2310000011462| [117112, 156087, ...|
+-------------+---------------------+
And I'm trying to find the entries whose lists in the second column have the most intersections with a given list of articles, e.g. [151574, 87239, 117908, 162475, 48599]
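For reference, what I'm doing now looks roughly like this (just a sketch; I use rdd.map here instead of foreach so the result can be collected, and the column names are taken from the sample above):
given = {151574, 87239, 117908, 162475, 48599}

# count how many of the given articles appear in each row's list
matches = df.rdd.map(
    lambda row: (row["cardnumber"], len(given.intersection(row["collect_list(article)"])))
)

# top 5 cardnumbers by number of matching articles
print(matches.top(5, key=lambda kv: kv[1]))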
Answer 1:
The only alternative here is a udf
, but it won't make much of a difference.
from pyspark.sql.functions import udf, collect_list

def intersect(xs):
    xs = set(xs)
    @udf("array<long>")
    def _(ys):
        return list(xs.intersection(ys))
    return _
It can be applied as:
a_list = [1, 4, 6]
df = spark.createDataFrame([
    (1, [3, 4, 8]), (2, [7, 2, 6])
], ("id", "articles"))
df.withColumn("intersect", intersect(a_list)("articles")).show()
# +---+---------+---------+
# | id| articles|intersect|
# +---+---------+---------+
# | 1|[3, 4, 8]| [4]|
# | 2|[7, 2, 6]| [6]|
# +---+---------+---------+
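If you then want the rows with the most matches, one option (a sketch, using the built-in size and desc functions) is to sort by the length of the intersection:
from pyspark.sql.functions import size, desc

(df.withColumn("intersect", intersect(a_list)("articles"))
   .withColumn("n_matches", size("intersect"))
   .orderBy(desc("n_matches"))
   .show())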
Based on the names, it looks like you use collect_list
, so your data probably looks like this:
df_long = spark.createDataFrame([
    (1, 3), (1, 4), (1, 8), (2, 7), (2, 2), (2, 6)
], ("id", "articles"))
In that case the problem is simpler. Join:
lookup = spark.createDataFrame(a_list, "long").toDF("articles")
joined = lookup.join(df_long, ["articles"])
and aggregate the result:
joined.groupBy("id").count().show()
# +---+-----+
# | id|count|
# +---+-----+
# | 1| 1|
# | 2| 1|
# +---+-----+
joined.groupBy("id").agg(collect_list("articles")).show()
# +---+----------------------+
# | id|collect_list(articles)|
# +---+----------------------+
# | 1| [4]|
# | 2| [6]|
# +---+----------------------+
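To answer the original question of which entries have the most matches, you could then order by the count (sketch):
from pyspark.sql.functions import desc

joined.groupBy("id").count().orderBy(desc("count")).show()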
Answer 2:
You can try the same set operation on the DataFrame instead of using rdd.foreach:
from pyspark.sql.functions import udf

my_udf = udf(lambda a, b: list(set(a).intersection(set(b))), "array<long>")
df = df.withColumn('intersect_value', my_udf('A', 'B'))
You can use the len function to get the size of the intersect list in the UDF itself and perform the operation you want on this DataFrame.
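For example, a sketch of a UDF that returns the intersection size directly (the column names 'A' and 'B' are placeholders, as above):
from pyspark.sql.functions import udf, desc
from pyspark.sql.types import IntegerType

size_udf = udf(lambda a, b: len(set(a).intersection(set(b))), IntegerType())
df = df.withColumn('intersect_size', size_udf('A', 'B'))
df.orderBy(desc('intersect_size')).show()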
Source: https://stackoverflow.com/questions/48547700/pyspark-compare-single-list-of-integers-to-column-of-lists