How to select a same-size stratified sample from a dataframe in Apache Spark?

Submitted by 吃可爱长大的小学妹 on 2019-12-05 23:16:27

Question


I have a dataframe in Spark 2, shown below, where each user has between 50 and several thousand posts. I would like to create a new dataframe that contains all the users from the original dataframe, but with only 5 randomly sampled posts for each user.

+--------+--------------+--------------------+
| user_id|       post_id|                text|
+--------+--------------+--------------------+
|67778705|44783131591473|some text...........|
|67778705|44783134580755|some text...........|
|67778705|44783136367108|some text...........|
|67778705|44783136970669|some text...........|
|67778705|44783138143396|some text...........|
|67778705|44783155162624|some text...........|
|67778705|44783688650554|some text...........|
|68950272|88655645825660|some text...........|
|68950272|88651393135293|some text...........|
|68950272|88652615409812|some text...........|
|68950272|88655744880460|some text...........|
|68950272|88658059871568|some text...........|
|68950272|88656994832475|some text...........|
+--------+--------------+--------------------+

Something like posts.groupby('user_id').agg(sample('post_id')) would be ideal, but there is no such function in PySpark.

Any advice?
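For reference, the result being asked for can be stated in plain Python on a small in-memory list. This is a sketch of the intended semantics only, not a Spark solution, and `sample_per_user` and the toy data are hypothetical. Every user is kept, and each user gets up to 5 distinct posts drawn at random without replacement.

```python
import random
from collections import defaultdict


def sample_per_user(posts, k=5, seed=None):
    """Return up to k randomly chosen post_ids per user, drawn without replacement.

    posts: iterable of (user_id, post_id) tuples, a toy stand-in for the dataframe.
    """
    rng = random.Random(seed)
    by_user = defaultdict(list)
    for user_id, post_id in posts:
        by_user[user_id].append(post_id)
    # Every user is kept; users with fewer than k posts keep all of them
    return {user: rng.sample(ids, min(k, len(ids))) for user, ids in by_user.items()}


# Toy data: one user with 7 posts, one with only 3
posts = [(67778705, i) for i in range(7)] + [(68950272, i) for i in range(100, 103)]
sample = sample_per_user(posts, k=5, seed=1)
print({user: len(ids) for user, ids in sample.items()})  # {67778705: 5, 68950272: 3}
```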

Update:

This question differs from the closely related question "stratified-sampling-in-spark" in two ways:

  1. It asks about disproportionate stratified sampling (the same number of posts per user) rather than the more common proportionate method used in the other question.
  2. It asks about doing this in Spark's Dataframe API rather than RDD.
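Point 1 can be made concrete. Proportionate sampling applies one fraction to every user, so a user's expected sample size grows with their post count. The fixed-size (disproportionate) scheme asked for here instead needs a per-user fraction equal to k divided by that user's post count. Below is a small plain-Python illustration with hypothetical counts; the function names are illustrative, not Spark APIs.

```python
def proportionate_fractions(counts, fraction):
    """Proportionate: the same sampling fraction for every user."""
    return {user: fraction for user in counts}


def fixed_size_fractions(counts, k):
    """Disproportionate: a per-user fraction targeting k posts (capped at 1.0)."""
    return {user: min(1.0, k / n) for user, n in counts.items()}


def expected_sizes(counts, fractions):
    """Expected posts kept per user if each post is kept with its user's fraction."""
    return {user: counts[user] * fractions[user] for user in counts}


counts = {"user_a": 50, "user_b": 1000}  # hypothetical post counts per user
print(expected_sizes(counts, proportionate_fractions(counts, 0.05)))  # about 2.5 and 50
print(expected_sizes(counts, fixed_size_fractions(counts, 5)))  # about 5 each
```

Note that these are only expected sizes. A scheme that keeps each row with a given probability will still vary around them.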

I have also updated the question's title to clarify this.


Answer 1:


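Using sampleBy gives only an approximate number of rows per user. Here is an alternative approach. It is a little more hacky than sampleBy, but it always returns exactly the same sample size for every user (or all of a user's posts, if they have fewer than that). The idea is to number each user's rows in random order with a window function and then keep the first 5.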

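import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rand, row_number}

// Number each user's posts in random order, then keep the first 5 per user
val w = Window.partitionBy(col("user_id")).orderBy(col("rnd"))
val sampled = df.withColumn("rnd", rand()).withColumn("row_num", row_number().over(w))
  .filter(col("row_num") <= 5).drop("rnd", "row_num")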

If you don't already have a random ID to order by, you can use org.apache.spark.sql.functions.rand to create a column of random values. This guarantees that the sample is random.
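The same idea can be written in plain Python as a sketch of what the window expression computes (this is not Spark code). The steps are: give every row a random key, rank the rows within each user by that key (the equivalent of row_number() over a window partitioned by user_id and ordered by the random column), and keep ranks 1 to k. The function name `top_k_by_random_rank` and the tuple layout are hypothetical.

```python
import random
from collections import defaultdict


def top_k_by_random_rank(rows, k, seed=None):
    """Emulate: withColumn("rnd", rand()), row_number() over
    partitionBy(user_id).orderBy(rnd), then filter(row_num <= k).

    rows: list of (user_id, post_id) tuples.
    Returns (user_id, post_id, row_num) for the rows that are kept.
    """
    rng = random.Random(seed)
    partitions = defaultdict(list)
    for user, post in rows:
        partitions[user].append((rng.random(), post))  # random sort key per row
    kept = []
    for user, items in partitions.items():
        items.sort(key=lambda item: item[0])  # order each partition by the random key
        # Slicing to the first k ranks plays the role of filter(row_num <= k)
        for row_num, (_, post) in enumerate(items[:k], start=1):
            kept.append((user, post, row_num))
    return kept


# Toy data: u1 has 10 posts, u2 has only 3
rows = [("u1", p) for p in range(10)] + [("u2", p) for p in range(100, 103)]
kept = top_k_by_random_rank(rows, k=5, seed=0)
print(len(kept))  # 8: five rows for u1 plus all three rows of u2
```

Each user's rows are ranked and cut at k, so the size per user is exactly min(k, n). The randomness only affects which rows are chosen.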




Answer 2:


You can use the DataFrame method .sampleBy(...), documented at http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy

Here's a working example:

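import random
import string

import numpy as np
from pyspark.sql import functions as func  # used by the commented-out balanced-sample lines

# assumes an existing SparkSession named `spark` (as in the pyspark shell)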

# generate some fake data
p = [(
    str(int(e)), 
    ''.join(
        random.choice(
            string.ascii_uppercase + string.digits) 
        for _ in range(10)
    )
) for e in np.random.normal(10, 1, 10000)]

posts = spark.createDataFrame(p, ['label', 'val'])

# define the sample size
percent_back = 0.05

# use this if you want an (almost) exact number of samples
# sample_count = 200
# percent_back = sample_count / posts.count()

frac = dict(
    (e.label, percent_back) 
    for e 
    in posts.select('label').distinct().collect()
)

# use this if you want an (almost) balanced sample
# f = posts.groupby('label').count()

# f_min_count can also be set to an exact number, e.g. f_min_count = 5,
# as long as it is no greater than the smallest per-label (per-user) count;
# otherwise some fractions would exceed 1

# alternatively, use the minimum post count across all labels
# f_min_count = f.select('count').agg(func.min('count').alias('minVal')).collect()[0].minVal

# f = f.withColumn('frac', f_min_count / func.col('count'))

# frac = dict(f.select('label', 'frac').collect())

# sample the data
sampled = posts.sampleBy('label', fractions=frac)

# compare the original counts with sampled
original_total_count = posts.count()
original_counts = posts.groupby('label').count()
original_counts = original_counts \
    .withColumn('count_perc', 
                original_counts['count'] / original_total_count)

sampled_total_count = sampled.count()
sampled_counts = sampled.groupBy('label').count()
sampled_counts = sampled_counts \
    .withColumn('count_perc', 
                sampled_counts['count'] / sampled_total_count)


# show() prints the table itself and returns None, so it is not wrapped in print()
original_counts.sort('label').show(100)
sampled_counts.sort('label').show(100)

print(sampled_total_count)
print(sampled_total_count / original_total_count)
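This explains why the counts from sampleBy are only approximate. As far as I understand its implementation, it keeps each row independently with probability equal to that row's stratum fraction. Strata missing from `fractions` are treated as 0, which the PySpark docs state. Under this behavior each user's sample size is binomially distributed around the target rather than fixed. The plain-Python model below assumes that per-row behavior. `sample_by` and `prob_exactly_k` are hypothetical helpers, not Spark functions.

```python
import random
from math import comb


def sample_by(rows, key, fractions, seed=None):
    """Model of sampleBy: keep each row independently with its stratum's fraction.

    rows: list of dicts; key: stratum column name;
    fractions: {stratum: probability}; unlisted strata are treated as 0.
    """
    rng = random.Random(seed)
    return [r for r in rows if rng.random() < fractions.get(r[key], 0.0)]


def prob_exactly_k(n, k):
    """P(X == k) for X ~ Binomial(n, k / n): the chance that a stratum of n rows
    sampled with fraction k / n yields exactly k rows."""
    p = k / n
    return comb(n, k) * p**k * (1 - p) ** (n - k)


# Toy data: one user with 1000 posts, balanced fraction 5 / 1000
rows = [{"user_id": "u1", "post_id": i} for i in range(1000)]
sizes = [len(sample_by(rows, "user_id", {"u1": 5 / 1000}, seed=s)) for s in range(10)]
print(sizes)  # sample sizes scatter around 5 instead of all being 5
print(round(prob_exactly_k(50, 5), 3), round(prob_exactly_k(1000, 5), 3))
```

Under this model, a user gets exactly 5 rows less than a fifth of the time, for both n = 50 and n = 1000. That is why the window-function approach in the first answer is the better fit when the per-user count must be exact.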


Source: https://stackoverflow.com/questions/41516805/how-to-select-a-same-size-stratified-sample-from-a-dataframe-in-apache-spark
