Question
I have the following data in a PySpark DataFrame called end_stats_df:
values  start  end  cat1  cat2
10      1      2    A     B
11      1      2    C     B
12      1      2    D     B
510     1      2    D     C
550     1      2    C     B
500     1      2    A     B
80      1      3    A     B
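For reference, a minimal sketch that recreates end_stats_df in a local Spark session; the column types are an assumption (integers for values, start and end, strings for cat1 and cat2):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Recreate the example data shown above
end_stats_df = spark.createDataFrame(
    [(10, 1, 2, 'A', 'B'),
     (11, 1, 2, 'C', 'B'),
     (12, 1, 2, 'D', 'B'),
     (510, 1, 2, 'D', 'C'),
     (550, 1, 2, 'C', 'B'),
     (500, 1, 2, 'A', 'B'),
     (80, 1, 3, 'A', 'B')],
    ['values', 'start', 'end', 'cat1', 'cat2'])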
And I want to aggregate it in the following way:
- I want to use the "start" and "end" columns as the aggregate keys.
- For each group of rows, I need to do the following:
  - Compute the number of unique values across both cat1 and cat2 for that group. E.g., for the group with start=1 and end=2, this number would be 4 because there are A, B, C, D. This number will be stored as n (n=4 in this example).
  - For the values field, for each group I need to sort the values and then select every (n-1)-th value, where n is the number stored from the first operation above (see the sketch after this list).
  - At the end of the aggregation, I don't really care what is in cat1 and cat2 after the operations above.
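To make the selection rule concrete, here is a small plain-Python sketch of the logic for the start=1, end=2 group. It follows the "every (n-1)-th sorted value" reading, which reproduces the example output given below; the variable names are illustrative only:

# Per-group logic for the start=1, end=2 group
group_values = [10, 11, 12, 510, 550, 500]
cat1 = ['A', 'C', 'D', 'D', 'C', 'A']
cat2 = ['B', 'B', 'B', 'C', 'B', 'B']

# Step 1: count the distinct categories across cat1 and cat2
n = len(set(cat1) | set(cat2))           # {'A', 'B', 'C', 'D'} -> n = 4

# Step 2: sort the values and keep every (n-1)-th one
sorted_values = sorted(group_values)     # [10, 11, 12, 500, 510, 550]
selected = sorted_values[n - 2::n - 1]   # indices 2 and 5 -> [12, 550]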
An example output from the example above is:
values  start  end  cat1  cat2
12      1      2    D     B
550     1      2    C     B
80      1      3    A     B
How do I accomplish this using PySpark DataFrames? I assume I need to use a custom UDAF, right?
Answer 1:
PySpark does not support UDAFs directly, so we have to do the aggregation manually.
from pyspark.sql import functions as f
from pyspark.sql.types import StringType  # declared return type of the UDF

def func(values, cat1, cat2):
    # n = number of distinct categories across cat1 and cat2 for the group
    n = len(set(cat1 + cat2))
    # return the (n-1)-th smallest of the collected values (list index n - 2)
    return sorted(values)[n - 2]

df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', sep='\t', header=True)
# collect each group's values into a list and its distinct categories into sets
df = df.groupBy(df['start'], df['end']).agg(f.collect_list(df['values']).alias('values'),
                                            f.collect_set(df['cat1']).alias('cat1'),
                                            f.collect_set(df['cat2']).alias('cat2'))
# wrap func as a UDF and apply it to the collected columns
df = df.select(df['start'], df['end'], f.udf(func, StringType())(df['values'], df['cat1'], df['cat2']))
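Not part of the original answer, but on Spark 3.0+ a grouped-map Pandas UDF (applyInPandas) is another way to sketch this; it can return several rows per group, which matches the example output in the question. This assumes pyarrow is installed, that end_stats_df has a numeric values column, that every group has at least two distinct categories, and that the "every (n-1)-th sorted value" reading is intended; pick_values is an illustrative name:

def pick_values(pdf):
    # pdf is a pandas DataFrame holding one (start, end) group
    n = pdf[['cat1', 'cat2']].stack().nunique()               # distinct categories across both columns
    picked = pdf['values'].sort_values().iloc[n - 2::n - 1]   # every (n-1)-th sorted value
    return pdf.loc[picked.index, ['values', 'start', 'end', 'cat1', 'cat2']]

result = (end_stats_df
          .groupBy('start', 'end')
          .applyInPandas(pick_values, schema=end_stats_df.schema))
result.show()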
Source: https://stackoverflow.com/questions/46187630/how-to-write-pyspark-udaf-on-multiple-columns