PySpark: how to duplicate a row n times in a DataFrame?

对着背影说爱祢 submitted on 2019-12-10 14:22:06

Question


I have a DataFrame like this, and I want to duplicate each row n times when its column n is greater than one:

A   B   n  
1   2   1  
2   9   1  
3   8   2    
4   1   1    
5   3   3 

And transform like this:

A   B   n  
1   2   1  
2   9   1  
3   8   2
3   8   2       
4   1   1    
5   3   3 
5   3   3 
5   3   3 

I think I should use explode, but I don't understand how it works...
Thanks


Answer 1:


The explode function returns a new row for each element in the given array or map.

One way to use this function is to write a udf that creates a list of length n for each row, and then explode the resulting array.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])

# udf that turns a value n into a list holding n copies of n, e.g. 3 -> [3, 3, 3]
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))

# explode emits one output row per array element, so each row appears n times
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
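To see what the udf and explode do row by row, here is a plain-Python model with no Spark involved. It is only a sketch of the semantics, assuming rows are (A, B, n) tuples; the helper names n_to_array and explode_n are hypothetical and are not Spark APIs.

```python
# Plain-Python model (no Spark) of the udf + explode approach.
# Assumption: rows are (A, B, n) tuples; this mirrors the semantics only,
# not how Spark executes the query.

def n_to_array(n):
    """Same logic as the udf: a list holding n copies of n."""
    return [n] * n


def explode_n(rows):
    """Replace n with its array, then emit one row per array element."""
    out = []
    for a, b, n in rows:
        for value in n_to_array(n):
            out.append((a, b, value))
    return out


rows = [(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)]
result = explode_n(rows)
for row in result:
    print(row)
```

The model also makes one edge case visible: a row with n = 0 yields an empty list and therefore disappears from the output. Spark's explode likewise drops rows whose array is empty or null, while explode_outer keeps them with a null value.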



Answer 2:


I think the udf answer by @Ahmed is the best way to go, but here is an alternative method that may be as good or better for small n:

First, collect the maximum value of n over the whole DataFrame:

import pyspark.sql.functions as f

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3

Now give each row an array of length max_n containing the numbers in range(max_n). This intermediate step produces a DataFrame like:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#|  A|  B|  n|  n_array|
#+---+---+---+---------+
#|  1|  2|  1|[0, 1, 2]|
#|  2|  9|  1|[0, 1, 2]|
#|  3|  8|  2|[0, 1, 2]|
#|  4|  1|  1|[0, 1, 2]|
#|  5|  3|  3|[0, 1, 2]|
#+---+---+---+---------+

Next, explode the n_array column and keep only the exploded values that are less than n. Since the array holds 0 through max_n - 1 and n is at most max_n, exactly n values pass the filter, which gives n copies of each row. Finally, drop the exploded column to get the end result:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'B', 'n', f.explode('n_array').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')\
    .show()
#+---+---+---+
#|  A|  B|  n|
#+---+---+---+
#|  1|  2|  1|
#|  2|  9|  1|
#|  3|  8|  2|
#|  3|  8|  2|
#|  4|  1|  1|
#|  5|  3|  3|
#|  5|  3|  3|
#|  5|  3|  3|
#+---+---+---+
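The same explode-then-filter logic can be modeled in plain Python. This is a sketch assuming rows are (A, B, n) tuples; the list comprehensions stand in for f.max, f.explode and the where filter, and say nothing about how Spark actually executes them.

```python
# Plain-Python model (no Spark) of the max_n + explode + filter approach.
# Assumption: rows are (A, B, n) tuples.

rows = [(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)]

# Step 1: global maximum of n (what the f.max('n') aggregation computes)
max_n = max(n for _, _, n in rows)

# Step 2: explode -> one candidate row per index 0 .. max_n - 1, for every row
exploded = [(a, b, n, i) for a, b, n in rows for i in range(max_n)]

# Step 3: keep indices below n (the where filter), then drop the helper column
result = [(a, b, n) for a, b, n, i in exploded if i < n]

for row in result:
    print(row)
```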

However, this creates an array of length max_n for every row, as opposed to an array of length n in the udf solution, so the explode emits max_n rows per input row before the filter discards the extras. It's not immediately clear to me how this will scale compared with the udf for large max_n, but I suspect the udf will win out.
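One rough way to compare the two approaches is to count the rows each explode emits. The sketch below is plain Python; exploded_row_counts is a hypothetical helper, and the skewed column is made-up input for illustration, not measured data. It counts rows only and ignores runtime costs such as the data transfer between the JVM and Python workers that a Python udf incurs.

```python
def exploded_row_counts(ns):
    """Rows emitted by explode in each approach, given the n of every row.

    udf approach: each row's array has length n, so explode emits sum(ns) rows.
    max_n approach: each row's array has length max(ns), so explode emits
    len(ns) * max(ns) rows, of which only sum(ns) survive the filter.
    """
    udf_rows = sum(ns)
    max_n_rows = len(ns) * max(ns)
    return udf_rows, max_n_rows


# The example from the question: n = 1, 1, 2, 1, 3
print(exploded_row_counts([1, 1, 2, 1, 3]))

# Hypothetical skewed column: 999,999 rows with n = 1 and one row with n = 1000
skewed = [1] * 999_999 + [1000]
print(exploded_row_counts(skewed))
```

For the question's data this gives 8 versus 15 rows. For the skewed column it gives 1,000,999 versus 1,000,000,000, so the max_n approach materializes roughly a thousand times more intermediate rows than it keeps. When every n equals max_n, the two counts are identical.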




Answer 3:


With Spark 2.4.0+, this is easier with the built-in functions array_repeat and explode:

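from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])

# array_repeat(n, int(n)) builds an array holding n copies of n;
# explode then emits one row per element, so each row appears n times
new_df = df.withColumn('n', expr('explode(array_repeat(n, int(n)))'))

new_df.show()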
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+


Source: https://stackoverflow.com/questions/50624745/pyspark-how-to-duplicate-a-row-n-time-in-dataframe
