Scikit-learn balanced subsampling

终归单人心 2020-12-02 10:34

I'm trying to create N balanced random subsamples of my large unbalanced dataset. Is there a way to do this simply with scikit-learn / pandas or do I have to implement it m

13 answers
  •  不知归路
    2020-12-02 11:18

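    Here is my solution, which plugs into an existing sklearn pipeline as a cross-validation splitter. It assumes binary labels where 0 is the major (negative) class and 1 is the minor (positive) class: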

    import numpy as np
    from sklearn.model_selection import RepeatedKFold
    from sklearn.utils import check_random_state


    class DownsampledRepeatedKFold(RepeatedKFold):
        """Repeated K-fold over balanced downsamples of a binary (0/1) dataset."""

        def __init__(self, n_splits=5, n_repeats=10, random_state=None):
            super().__init__(
                n_splits=n_splits, n_repeats=n_repeats, random_state=random_state
            )
            # kept as an attribute because split() below reads it
            self.n_splits = n_splits

        def split(self, X, y=None, groups=None):
            y = np.asarray(y)
            # honor random_state instead of reseeding the global generator
            rng = check_random_state(self.random_state)
            # get index of major class (negative)
            idxs_class0 = np.flatnonzero(y == 0)
            # get index of minor class (positive)
            idxs_class1 = np.flatnonzero(y == 1)
            # get length of minor class
            len_minor = len(idxs_class1)
            for _ in range(self.n_repeats):
                # subsample of major class of size minor class;
                # replace=False keeps the same row from being drawn twice
                idxs_class0_downsampled = rng.choice(idxs_class0, size=len_minor, replace=False)
                original_indx_downsampled = np.hstack((idxs_class0_downsampled, idxs_class1))
                rng.shuffle(original_indx_downsampled)
                cv = self.cv(n_splits=self.n_splits, shuffle=True, random_state=rng)
                # map fold positions back to row indices of the original data
                for train_index, test_index in cv.split(original_indx_downsampled):
                    yield original_indx_downsampled[train_index], original_indx_downsampled[test_index]
    

    Use it as usual:

        for train_index, test_index in DownsampledRepeatedKFold(n_splits=5, n_repeats=10).split(X, y):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]
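
    If you only need the N balanced subsamples the question asks for, without the cross-validation folds, a plain NumPy sketch may be enough. The helper name `balanced_subsamples` and its parameters are hypothetical (not part of scikit-learn or pandas), and it assumes you want every class downsampled, without replacement, to the size of the rarest class:

```python
import numpy as np


def balanced_subsamples(y, n_subsamples, random_state=None):
    """Yield index arrays that each hold the same number of rows per class.

    Hypothetical helper: every class is downsampled, without replacement,
    to the size of the rarest class.
    """
    y = np.asarray(y)
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(y, return_counts=True)
    n_per_class = counts.min()
    # row indices belonging to each class, computed once
    class_indices = [np.flatnonzero(y == c) for c in classes]
    for _ in range(n_subsamples):
        # independent draw per class for every subsample
        picked = [
            rng.choice(idx, size=n_per_class, replace=False)
            for idx in class_indices
        ]
        sample = np.concatenate(picked)
        rng.shuffle(sample)
        yield sample


# Toy labels: 90 negatives and 10 positives
y = np.array([0] * 90 + [1] * 10)
subsamples = list(balanced_subsamples(y, n_subsamples=3, random_state=0))
```

    Each yielded array indexes into the original data, so `X[sample]` and `y[sample]` give one balanced subsample. The minority rows are the same in every subsample; only the majority rows change between draws.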
    
