PySpark - Get indices of duplicate rows


Question


Let's say I have a PySpark data frame, like so:

+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
+--+--+--+--+

How can I create a column marking all of the duplicate rows, like so:

+--+--+--+--+--+
|a |b |c |d |e |
+--+--+--+--+--+
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|1 |0 |1 |2 |1 |
|0 |4 |3 |1 |0 |
+--+--+--+--+--+

I attempted it using the groupBy and aggregate functions to no avail.
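
For reference, a minimal sketch of how the sample DataFrame above could be constructed (assuming an existing SparkSession named spark):

# Hypothetical construction of the example data, assuming a SparkSession named `spark`.
df = spark.createDataFrame(
    [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)],
    ["a", "b", "c", "d"],
)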


Answer 1:


Just to expand on my comment:

You can group by all of the columns and use pyspark.sql.functions.count() to determine if a row is duplicated:

import pyspark.sql.functions as f
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we use count("*") > 1 as the aggregate function, and cast the result to an int. Note that the groupBy() collapses each group of duplicate rows into a single row. Depending on your needs, this may be sufficient.

However, if you'd like to keep all of the rows, you can use a Window function like shown in the other answers OR you can use a join():

df.join(
    df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")),
    on=df.columns,
    how="inner"
).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we inner join the original DataFrame, on all of the columns, with the result of the groupBy() above.
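
If you only need the duplicated rows afterwards, you can filter on the new flag column. A small usage sketch, where flagged is just a hypothetical name for the joined result above:

# Hypothetical follow-up: keep the joined result and filter to the flagged duplicates.
flagged = df.join(
    df.groupBy(df.columns).agg((f.count("*") > 1).cast("int").alias("e")),
    on=df.columns,
    how="inner",
)
flagged.filter(f.col("e") == 1).show()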




Answer 2:


Define a window function to check whether the count of rows, when partitioned by all columns, is greater than 1. If it is, the row is a duplicate (1); otherwise it is not (0).

allColumns = df.columns
import sys
from pyspark.sql import functions as f
from pyspark.sql import window as w

# Window spec covering every row in each partition of identical rows.
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(-sys.maxsize, sys.maxsize)

df.withColumn('e', f.when(f.count(f.col('d')).over(windowSpec) > 1, f.lit(1)).otherwise(f.lit(0))).show(truncate=False)

which should give you

+---+---+---+---+---+
|a  |b  |c  |d  |e  |
+---+---+---+---+---+
|1  |0  |1  |2  |1  |
|1  |0  |1  |2  |1  |
|0  |2  |0  |1  |0  |
|0  |4  |3  |1  |0  |
+---+---+---+---+---+

I hope the answer is helpful.

Updated

As @pault commented, you can eliminate when, col and lit by casting the boolean to integer:

df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)
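
As a side note, the same whole-partition frame can be written with the named boundaries on Window instead of sys.maxsize. A small sketch reusing the names from the snippet above:

# Equivalent window spec using the named unbounded frame boundaries.
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(
    w.Window.unboundedPreceding, w.Window.unboundedFollowing
)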



Answer 3:


Partition your DataFrame by all the columns and then apply dense_rank.

from pyspark.sql.functions import dense_rank
from pyspark.sql import window as w

# dense_rank() requires an ordered window spec; ordering by the same columns keeps identical rows tied.
df.withColumn('e', dense_rank().over(w.Window.partitionBy(df.columns).orderBy(df.columns))).show()


Source: https://stackoverflow.com/questions/50865803/pyspark-get-indices-of-duplicate-rows
