Fill Pyspark dataframe column null values with average value from same column

走远了吗. Posted on 2019-12-01 15:45:36

Question


With a dataframe like this,

rdd_2 = sc.parallelize([(0,10,223,"201601"), (0,10,83,"2016032"),(1,20,None,"201602"),(1,20,3003,"201601"), (1,20,None,"201603"), (2,40, 2321,"201601"), (2,30, 10,"201602"),(2,61, None,"201601")])

df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
df_data.show()

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|null| 201602|
|  1|  20|3003| 201601|
|  1|  20|null| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|null| 201601|
+---+----+----+-------+

I need to fill the null values with the average of the existing values, with the expected result being

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

where 1128 is the average of the existing values. I need to do that for several columns.

My current approach is to use na.fill:

fill_values = {column: df_data.agg({column:"mean"}).flatMap(list).collect()[0] for column in df_data.columns if column not in ['date','id']}
df_data = df_data.na.fill(fill_values)

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

But this is very cumbersome. Any ideas?


Answer 1:


Well, one way or another you have to:

  • compute statistics
  • fill the blanks

That limits how much you can really improve here. Still, you can:

  • replace flatMap(list).collect()[0] with first()[0] or structure unpacking
  • compute all stats with a single action
  • use built-in Row methods to extract dictionary

The final result could look like this:

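from pyspark.sql.functions import avg

def fill_with_mean(df, exclude=()):
    # Compute the mean of every non-excluded column in a single aggregation.
    stats = df.agg(*(
        avg(c).alias(c) for c in df.columns if c not in exclude
    ))
    # first() returns a single Row; asDict() maps column name -> mean.
    return df.na.fill(stats.first().asDict())

fill_with_mean(df_data, ["id", "date"])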
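
One edge case fill_with_mean does not cover: if a column contains only nulls, its average is itself null, and as far as I know na.fill does not accept None as a replacement value. Below is a minimal sketch of a variant that skips such columns and only considers numeric ones. The names fill_numeric_with_mean and NUMERIC_TYPES are hypothetical, and the sketch swaps in selectExpr with SQL avg() for the single aggregation, so it is an illustration under those assumptions rather than a definitive implementation.

```python
# Hypothetical helper; the names here are illustrative, not part of PySpark.
# Type strings as reported by DataFrame.dtypes for plain numeric columns.
NUMERIC_TYPES = {"tinyint", "smallint", "int", "bigint", "float", "double"}


def fill_numeric_with_mean(df, exclude=()):
    """Fill nulls in the numeric columns of a Spark DataFrame with the column mean."""
    # df.dtypes yields (column name, type string) pairs such as ("cost", "bigint").
    cols = [
        name for name, dtype in df.dtypes
        if name not in exclude and dtype in NUMERIC_TYPES
    ]
    if not cols:
        return df
    # One global aggregation computes every mean in a single Spark job.
    stats = df.selectExpr(*["avg(`{0}`) AS `{0}`".format(c) for c in cols])
    means = stats.first().asDict()
    # avg() over an all-null column is null; drop those entries before filling.
    fill_values = {c: m for c, m in means.items() if m is not None}
    return df.na.fill(fill_values) if fill_values else df
```

With the data from the question, fill_numeric_with_mean(df_data, exclude=["id"]) should leave the string column date alone without it being listed explicitly.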

In Spark 2.2 or later you can also use Imputer (pyspark.ml.feature.Imputer). See the related question "Replace missing values with mean - Spark Dataframe".
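
As a rough illustration of the Imputer route, here is a minimal sketch. The function name impute_with_mean and the "_imputed" suffix for the output columns are hypothetical choices. The cast to double is a defensive assumption, since to my knowledge earlier Spark releases only accepted float or double input columns.

```python
def impute_with_mean(df, columns, suffix="_imputed"):
    """Add a mean-imputed copy of each column in `columns` using Imputer."""
    # Imported inside the function so pyspark.ml is only loaded when it is used.
    from pyspark.ml.feature import Imputer

    columns = list(columns)
    # Assumption: cast to double so the inputs are valid on older Spark versions.
    for c in columns:
        df = df.withColumn(c, df[c].cast("double"))

    # Each input column gets a separate output column, e.g. cost -> cost_imputed.
    imputer = Imputer(
        strategy="mean",
        inputCols=columns,
        outputCols=[c + suffix for c in columns],
    )
    # fit() computes the means; transform() adds the imputed columns.
    return imputer.fit(df).transform(df)
```

For the question's data this would be called as impute_with_mean(df_data, ["type", "cost"]). Unlike na.fill, the result keeps the original columns and adds the imputed ones alongside them.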



Source: https://stackoverflow.com/questions/37749759/fill-pyspark-dataframe-column-null-values-with-average-value-from-same-column
