Scikit-learn balanced subsampling

终归单人心 2020-12-02 10:34

I'm trying to create N balanced random subsamples of my large unbalanced dataset. Is there a way to do this simply with scikit-learn / pandas, or do I have to implement it myself?
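One common reading of "N balanced subsamples" is undersampling: each subsample takes the same number of rows from every class, equal to the size of the smallest class. Below is a minimal numpy-only sketch under that assumption. The function name balanced_subsample_indices is hypothetical and is not a scikit-learn or pandas API. It returns index arrays so the same subsample can be applied to X, y, or any other aligned array. (pandas 1.1 and later also provide DataFrameGroupBy.sample, which can draw a fixed number of rows per group.)

```python
import numpy as np


def balanced_subsample_indices(y, n_subsamples, random_seed=None):
    """Hypothetical helper: return a list of n_subsamples index arrays.

    Each array holds the same number of rows from every class (the size of
    the smallest class), drawn without replacement within a subsample.
    """
    rng = np.random.default_rng(random_seed)
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    per_class = counts.min()  # minority-class size
    class_idx = [np.flatnonzero(y == c) for c in classes]

    subsamples = []
    for _ in range(n_subsamples):
        # undersample every class to per_class rows, without replacement
        idx = np.concatenate(
            [rng.choice(ci, size=per_class, replace=False) for ci in class_idx]
        )
        rng.shuffle(idx)  # mix the classes within the subsample
        subsamples.append(idx)
    return subsamples


# Toy data: 90 rows of class 0, 10 rows of class 1
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 90 + [1] * 10)

subs = balanced_subsample_indices(y, n_subsamples=5, random_seed=0)
X_0, y_0 = X[subs[0]], y[subs[0]]
print(np.bincount(y_0))  # [10 10]
```

Undersampling discards majority rows in each subsample, but because every subsample draws a fresh random set, the N subsamples together still cover much of the majority class.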

13 answers
  •  [愿得一人]
    2020-12-02 10:58

    Here is a version of the code from an earlier answer that works for multiclass labels (I tested it with classes 0, 1, 2, 3, 4):

    import numpy as np

    def balanced_sample_maker(X, y, sample_size, random_seed=None):
        """Return a balanced data set by drawing `sample_size` observations,
        with replacement, from every class in `y`.

        Because sampling is with replacement, rows may repeat, and classes
        with fewer than `sample_size` observations are oversampled.

        Parameters
        ----------
        X : numpy.ndarray, shape (n_samples, n_features)
        y : numpy.ndarray, shape (n_samples,)
        sample_size : int
            Number of observations to draw from each class.
        random_seed : int or None
            Seed for reproducible sampling.

        Returns
        -------
        (X_balanced, y_balanced, balanced_copy_idx)
        """
        y = np.asarray(y)
        # local generator, so the global NumPy random state is not modified
        rng = np.random.default_rng(random_seed)

        # find the observation indices of each class level
        groupby_levels = {level: np.flatnonzero(y == level) for level in np.unique(y)}

        # draw sample_size indices from each class, with replacement
        balanced_copy_idx = []
        for gb_idx in groupby_levels.values():
            over_sample_idx = rng.choice(gb_idx, size=sample_size, replace=True).tolist()
            balanced_copy_idx += over_sample_idx
        # shuffle so the classes are interleaved rather than in blocks
        rng.shuffle(balanced_copy_idx)

        return X[balanced_copy_idx, :], y[balanced_copy_idx], balanced_copy_idx
    

    The function also returns the sampled indices, so they can be applied to other arrays aligned with X and used to track how often each observation was drawn (helpful for training).
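To illustrate the usage-tracking idea, the returned indices can be tallied with np.bincount across several sampling rounds. The sketch below is self-contained, so it uses a hypothetical index-only helper, sample_indices, that mimics the per-class sampling with replacement done by balanced_sample_maker; in practice you would accumulate the third value that balanced_sample_maker returns instead.

```python
import numpy as np


def sample_indices(y, sample_size, rng):
    """Hypothetical index-only stand-in for balanced_sample_maker:
    draw sample_size indices per class, with replacement, and shuffle."""
    idx = np.concatenate(
        [rng.choice(np.flatnonzero(y == c), size=sample_size, replace=True)
         for c in np.unique(y)]
    )
    rng.shuffle(idx)
    return idx


rng = np.random.default_rng(42)
y = np.array([0] * 8 + [1] * 2)  # 8 majority rows, 2 minority rows

# Count how many times each observation was drawn over 10 rounds
usage = np.zeros(len(y), dtype=int)
for _ in range(10):
    idx = sample_indices(y, sample_size=5, rng=rng)
    usage += np.bincount(idx, minlength=len(y))

print(usage)        # draw count per observation
print(usage.sum())  # 10 rounds * 2 classes * 5 draws = 100
```

Passing minlength=len(y) keeps the count vector aligned with the rows of y even when the highest-numbered rows are never drawn in a round.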
