Stratified sampling with pyspark

Asked by 误落风尘 on 2020-12-09 13:12 · 4 answers · 1276 views

I have a Spark DataFrame with one column that contains mostly zeros and very few ones (only 0.01% are ones).

I'd like to take a random

4 Answers

  •  慢半拍i (OP) · 2020-12-09 13:23

    This is based on the accepted answer by @eliasah and this SO thread.

    If you want to get back a train and a test set, you can use the following function:

    from pyspark.sql import functions as F

    def stratified_split_train_test(df, frac, label, join_on, seed=42):
        """Stratified split of a DataFrame into train and test sets.
        Inspiration:
        https://stackoverflow.com/a/47672336/1771155
        https://stackoverflow.com/a/39889263/1771155"""
        # Build a {label_value: frac} map so every class is sampled at the same rate.
        fractions = df.select(label).distinct().withColumn("fraction", F.lit(frac)).rdd.collectAsMap()
        # Approximate stratified sample: roughly `frac` of each class ends up in train.
        df_frac = df.stat.sampleBy(label, fractions, seed)
        # Everything not sampled becomes the test set (requires a unique join key).
        df_remaining = df.join(df_frac, on=join_on, how="left_anti")
        return df_frac, df_remaining
    

    To create a stratified train and test set where 80% of the total is used for the training set:

    df_train, df_test = stratified_split_train_test(df=df, frac=0.8, label="y", join_on="unique_id")
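    The same per-class split can be sketched in pure Python, which makes the idea easy to verify without a Spark cluster. `stratified_split` below is a hypothetical helper (not part of PySpark): it groups rows by label and takes `frac` of each group for train, so the class balance is preserved exactly. Note that PySpark's `sampleBy` is probabilistic, so its per-class fractions are only approximate, whereas this sketch is exact.

    ```python
    import random
    from collections import defaultdict

    def stratified_split(rows, frac, label_key, seed=42):
        """Pure-Python sketch of a stratified train/test split:
        take `frac` of each label's rows for train, the rest for test.
        (Hypothetical helper for illustration, not part of PySpark.)"""
        rng = random.Random(seed)
        # Group rows by their label value.
        by_label = defaultdict(list)
        for row in rows:
            by_label[row[label_key]].append(row)
        train, test = [], []
        for label_value, group in by_label.items():
            # Shuffle within each class, then split at the frac boundary.
            rng.shuffle(group)
            k = int(len(group) * frac)
            train.extend(group[:k])
            test.extend(group[k:])
        return train, test
    ```

    On 100 rows with 90 zeros and 10 ones and `frac=0.8`, this yields a train set of 72 zeros and 8 ones, so both splits keep the original 9:1 class ratio.
    
    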
    
