Stratified Sampling in Pandas

慢半拍i 2020-12-12 22:17

I've looked at the sklearn stratified sampling docs as well as the pandas docs, and also at "Stratified samples from Pandas" and "sklearn stratified sampling based on a column", but

3 Answers
  • 2020-12-12 23:06

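    Use min() when passing the number of rows to sample, so that a group with fewer rows than requested is taken whole instead of raising a ValueError. Consider the dataframe df: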

    import pandas as pd

    df = pd.DataFrame(dict(
            A=[1, 1, 1, 2, 2, 2, 2, 3, 4, 4],
            B=range(10)
        ))
    
    # Take 2 rows per group, or the whole group if it has fewer than 2 rows.
    df.groupby('A', group_keys=False).apply(lambda x: x.sample(min(len(x), 2)))
    
       A  B
    1  1  1
    2  1  2
    3  2  3
    6  2  6
    7  3  7
    9  4  9
    8  4  8
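    An alternative sketch that avoids apply (not the answer's method; k = 2 and random_state=0 are example values): shuffle the rows once, then keep the first k rows of each group with GroupBy.head, which simply returns the whole group when it has fewer than k rows.

```python
import pandas as pd

df = pd.DataFrame(dict(
    A=[1, 1, 1, 2, 2, 2, 2, 3, 4, 4],
    B=range(10),
))

k = 2  # maximum rows per group (example value)

# Shuffle all rows once, then keep the first k rows of each group.
# GroupBy.head(k) returns fewer rows for groups smaller than k instead of raising.
sample = df.sample(frac=1, random_state=0).groupby("A").head(k)

print(sample.sort_index())
```

    Because the shuffle happens before grouping, which rows end up first in each group is random, so head() acts as a per-group sample without replacement.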
    
  • 2020-12-12 23:13

    Extending the groupby answer, we can make sure the sample is balanced. If every class has at least n_samples rows, we simply take n_samples rows from each class (as in the previous answer). If the minority class has fewer than n_samples rows, we take from every class as many rows as the minority class contains.

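    import pandas as pd

    def stratified_sample_df(df, col, n_samples):
        # Same per-class size for every class, capped by the minority class size.
        n = min(n_samples, df[col].value_counts().min())
        df_ = df.groupby(col).apply(lambda x: x.sample(n))
        # With the default group_keys=True, apply() prepends the class label as an index level; drop it.
        df_.index = df_.index.droplevel(0)
        return df_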
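    A sketch of the same balanced draw using DataFrameGroupBy.sample (available since pandas 1.1), which returns the sampled rows under their original index, so no droplevel step is needed. The random_state parameter and the example frame are additions for reproducibility, not part of the answer.

```python
import pandas as pd


def stratified_sample_df(df, col, n_samples, random_state=None):
    # Cap the per-class draw at the size of the smallest class so every
    # class contributes the same number of rows.
    n = int(min(n_samples, df[col].value_counts().min()))
    # GroupBy.sample draws n rows from each group without replacement
    # and keeps the original row labels.
    return df.groupby(col).sample(n=n, random_state=random_state)


df = pd.DataFrame(dict(
    A=[1, 1, 1, 2, 2, 2, 2, 3, 4, 4],
    B=range(10),
))

# The minority class (A == 3) has a single row, so one row is drawn per class.
print(stratified_sample_df(df, "A", n_samples=2, random_state=0))
```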
    
  • 2020-12-12 23:16

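    The following samples about N rows in total, with each group appearing in its original proportion (each group's count is rounded to the nearest integer, so the total can differ slightly from N), then shuffles the result and resets the index. Example dataframe: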

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(dict(
        A=[1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4],
        B=range(20)
    ))

    N = 10  # total number of rows to sample (example value)
    

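    Short and sweet (but note that weights='A' uses the values in column A as per-row sampling weights, so rows with larger A values are more likely to be drawn; this is a weighted random sample, not a stratified one, and it does not preserve the group proportions):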

    df.sample(n=N, weights='A', random_state=1).reset_index(drop=True)
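    To see what weights='A' does on this frame, compare each group's share of rows with its share of the total weight: on the first draw, each group is picked with probability equal to its share of the total weight, so groups with larger A values are over-represented. A small check (a sketch; the variable names are illustrative):

```python
import pandas as pd

df = pd.DataFrame(dict(
    A=[1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4],
    B=range(20),
))

# Share of rows each group has in the original frame.
row_share = df["A"].value_counts(normalize=True).sort_index()

# Under weights="A", a row's draw probability is proportional to its A value,
# so each group's share of the total weight is sum(A in group) / sum(A).
weight_share = (df.groupby("A")["A"].sum() / df["A"].sum()).sort_index()

print(pd.DataFrame({"row_share": row_share, "weight_share": weight_share}))
```

    Group 4 holds 25% of the rows but 20/45 (about 44%) of the weight, while group 1 holds 35% of the rows but only 7/45 (about 16%) of the weight.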
    

    Long version:

    df.groupby('A', group_keys=False).apply(lambda x: x.sample(int(np.rint(N*len(x)/len(df))))).sample(frac=1).reset_index(drop=True)
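    Because each group's count is rounded independently, the total can differ from N: with the 20-row frame above and N = 6, the per-group quotas are 2.1, 1.8, 0.6 and 1.5, which np.rint rounds to 2, 2, 1 and 2, i.e. 7 rows. If an exact total matters, one option (a swapped-in technique, the largest-remainder method, not part of the answer) is to floor every quota and hand the leftover rows to the groups with the largest fractional parts. The helper names proportional_allocation and proportional_sample are hypothetical.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(dict(
    A=[1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4],
    B=range(20),
))


def proportional_allocation(sizes, N):
    """Split N rows across groups in proportion to sizes (largest-remainder method)."""
    quotas = sizes * N / sizes.sum()        # exact proportional share per group
    counts = np.floor(quotas).astype(int)   # round every share down
    leftover = int(N - counts.sum())        # rows still to hand out
    if leftover > 0:
        # Give one extra row to the groups with the largest fractional parts.
        by_remainder = (quotas - counts).sort_values(ascending=False).index
        counts.loc[by_remainder[:leftover]] += 1
    return counts


def proportional_sample(df, col, N, random_state=None):
    """Hypothetical helper: draw exactly N rows stratified by col, then shuffle."""
    counts = proportional_allocation(df[col].value_counts(), N)
    parts = [
        group.sample(n=int(counts.loc[key]), random_state=random_state)
        for key, group in df.groupby(col)
    ]
    # Shuffle the combined rows and reset the index, as in the long version.
    combined = pd.concat(parts)
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


N = 6
print(proportional_allocation(df["A"].value_counts(), N).sort_index())
print(proportional_sample(df, "A", N, random_state=1))
```

    For N = 6 this allocates 2, 2, 1 and 1 rows to groups 1 to 4, so the total is exactly 6, at the cost of group 4 getting one row fewer than plain rounding would give it.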
    