How to select a same-size stratified sample from a dataframe in Apache Spark?

别说谁变了你拦得住时间么 submitted on 2019-12-04 03:55:11

Using sampleBy yields only an approximate result: each row is kept or dropped by an independent Bernoulli trial at its stratum's fraction, so the returned per-stratum counts fluctuate around fraction * count. Here is an alternative approach that is a little more hacky, but always returns exactly the requested sample size per stratum.

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

// rank rows within each stratum in a random order, then keep the first n
df.withColumn("row_num",
    row_number().over(Window.partitionBy($"user_id").orderBy($"something_random")))
  .where($"row_num" <= n)  // n = exact number of rows to keep per stratum

If you don't already have a random column, you can use org.apache.spark.sql.functions.rand to create one (e.g. orderBy(rand())), which guarantees that the ordering within each stratum, and hence the sample, is random.
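The same idea (shuffle each group, then keep the first n rows) can be sketched in plain Python without a Spark session; the data, key function, and n below are made up for illustration:

```python
import random
from collections import defaultdict

def exact_stratified_sample(rows, key, n, seed=0):
    """Take exactly n rows per stratum: shuffle each group, keep the first n."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for row in rows:
        groups[key(row)].append(row)
    sample = []
    for group in groups.values():
        rng.shuffle(group)        # analogue of orderBy(rand())
        sample.extend(group[:n])  # analogue of row_number() <= n
    return sample

rows = [(user, i) for user in ("a", "b", "c") for i in range(10)]
sample = exact_stratified_sample(rows, key=lambda r: r[0], n=3)
```

Unlike sampleBy, this always returns exactly n rows per group (assuming each group has at least n rows).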

You can use the .sampleBy(...) method for DataFrames: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy

Here's a working example:

import numpy as np
import string
import random

from pyspark.sql import functions as func  # used by the balanced-sample variant below

# generate some fake data
p = [(
    str(int(e)), 
    ''.join(
        random.choice(
            string.ascii_uppercase + string.digits) 
        for _ in range(10)
    )
) for e in np.random.normal(10, 1, 10000)]

posts = spark.createDataFrame(p, ['label', 'val'])

# define the sample size
percent_back = 0.05

# use this if you want an (almost) exact number of samples
# sample_count = 200
# percent_back = sample_count / posts.count()

frac = dict(
    (e.label, percent_back) 
    for e 
    in posts.select('label').distinct().collect()
)
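The fractions dictionary above applies the same fraction to every label; for an (almost) exact total sample size, that fraction is just the target count divided by the total count. A plain-Python sketch of that computation (the label counts here are made up):

```python
def fractions_for_target(label_counts, sample_count):
    """Per-label sampling fraction aiming for ~sample_count rows overall,
    spread proportionally (the same fraction for every label)."""
    total = sum(label_counts.values())
    pct = sample_count / total
    return {label: pct for label in label_counts}

# 7300 rows total, aiming for 365 sampled rows -> fraction 0.05 everywhere
frac = fractions_for_target({'8': 300, '9': 2400, '10': 4600}, sample_count=365)
```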

# use this if you want (almost) balanced sample
# f = posts.groupby('label').count()

# f_min_count can also be set to an exact number,
# e.g. f_min_count = 5,

# as long as it is less than the minimum count of posts per label
# across all labels

# alternatively, you can take the minimum post count
# f_min_count = f.select('count').agg(func.min('count').alias('minVal')).collect()[0].minVal

# f = f.withColumn('frac',f_min_count/func.col('count'))

# frac = dict(f.select('label', 'frac').collect())
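The balanced variant above assigns each label the fraction that would pull roughly f_min_count rows from it. The per-label arithmetic can be sketched in plain Python (the counts below are made up):

```python
def balanced_fractions(label_counts, f_min_count=None):
    """Fraction per label so each stratum yields ~f_min_count rows.
    Defaulting to the smallest stratum's count caps every fraction at 1.0."""
    if f_min_count is None:
        f_min_count = min(label_counts.values())
    return {label: f_min_count / count for label, count in label_counts.items()}

# smallest stratum ('8', 200 rows) gets fraction 1.0; larger strata get less
frac = balanced_fractions({'8': 200, '9': 2400, '10': 4600})
```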

# sample the data
sampled = posts.sampleBy('label', fractions=frac)

# compare the original counts with sampled
original_total_count = posts.count()
original_counts = posts.groupby('label').count()
original_counts = original_counts \
    .withColumn('count_perc', 
                original_counts['count'] / original_total_count)

sampled_total_count = sampled.count()
sampled_counts = sampled.groupBy('label').count()
sampled_counts = sampled_counts \
    .withColumn('count_perc', 
                sampled_counts['count'] / sampled_total_count)


original_counts.sort('label').show(100)  # show() prints; wrapping it in print() just adds "None"
sampled_counts.sort('label').show(100)

print(sampled_total_count)
print(sampled_total_count / original_total_count)