How to calculate mean and standard deviation given a PySpark DataFrame?


I have a PySpark DataFrame (not pandas) called df that is too large to call collect() on. Therefore, any approach that collects the values to the driver is not efficient.

3 Answers
  •  天命终不由人
    2021-02-07 15:26

    You can use mean and stddev from pyspark.sql.functions. The aggregation runs distributed on the executors, and the final collect() only brings back a single row of results, so it stays cheap regardless of the DataFrame's size:

    import pyspark.sql.functions as F
    
    df = spark.createDataFrame(
        [(680, [[691, 1], [692, 5]]), (685, [[691, 2], [692, 2]]), (684, [[691, 1], [692, 3]])],
        ["product_PK", "products"]
    )
    
    result_df = (
        df
        # collect the second element of each [id, value] pair into an array
        .withColumn(
            'val_list',
            F.array(df.products.getItem(0).getItem(1), df.products.getItem(1).getItem(1))
        )
        # turn the array into one row per value
        .select(F.explode('val_list').alias('val'))
        # aggregate all values into a single row of statistics
        .select(F.mean('val').alias('mean'), F.stddev('val').alias('stddev'))
    )
    
    print(result_df.collect())
    

    which outputs:

    [Row(mean=2.3333333333333335, stddev=1.505545305418162)]
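
    Note that F.stddev is an alias for F.stddev_samp, the sample standard deviation (normalized by n - 1). If you want the population standard deviation instead, swap in F.stddev_pop. A minimal sketch, reusing the same pipeline:

    # population standard deviation: normalizes by n instead of n - 1
    pop_df = (
        df
        .withColumn(
            'val_list',
            F.array(df.products.getItem(0).getItem(1), df.products.getItem(1).getItem(1))
        )
        .select(F.explode('val_list').alias('val'))
        .select(F.stddev_pop('val').alias('stddev_pop'))
    )
    print(pop_df.collect())
    # should print roughly [Row(stddev_pop=1.3744)] for these six values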
    

    You can read more in the pyspark.sql.functions documentation.
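
    As a side note, the snippet above hardcodes the two array positions. If the products arrays can hold an arbitrary number of [id, value] pairs, a sketch of the same computation that explodes the array itself (assuming the same df as above):

    # one row per [id, value] pair, regardless of array length
    generic_df = (
        df
        .select(F.explode('products').alias('pair'))
        .select(F.col('pair').getItem(1).alias('val'))
        .select(F.mean('val').alias('mean'), F.stddev('val').alias('stddev'))
    )
    print(generic_df.collect())
    # same result: [Row(mean=2.3333333333333335, stddev=1.505545305418162)]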
