How to generate a custom cross-validation generator in scikit-learn?

萌比男神i 2021-01-31 20:08

I have an unbalanced dataset, so I have a strategy for oversampling that I apply only while training on my data. I'd like to use scikit-learn classes like GridSearchCV

4 Answers
  •  暖寄归人
    2021-01-31 20:52

    I had a similar problem and this quick hack is working for me:

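    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    class UpsampleStratifiedKFold:
        def __init__(self, n_splits=3, random_state=None):
            self.n_splits = n_splits
            self.random_state = random_state

        def split(self, X, y, groups=None):
            y = np.asarray(y)  # so y[rx] is positional even for a pandas Series
            rng = np.random.default_rng(self.random_state)
            for rx, tx in StratifiedKFold(n_splits=self.n_splits).split(X, y):
                # Positions within the training fold of class 0 (majority) and class 1 (minority)
                nix = np.where(y[rx] == 0)[0]
                pix = np.where(y[rx] == 1)[0]
                # Resample the minority class with replacement up to the majority count
                pixu = rng.choice(pix, size=nix.shape[0], replace=True)
                ix = np.append(nix, pixu)
                rxm = rx[ix]  # map fold positions back to row indices of X / y
                yield rxm, tx  # balanced training fold, untouched test fold

        def get_n_splits(self, X=None, y=None, groups=None):
            return self.n_splits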
    

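    This upsamples the minority class (with replacement) so that each (k-1)-fold training set is balanced, but it leaves the kth (test) fold unbalanced. It assumes binary labels, with 0 as the majority class and 1 as the minority class. It appears to work well with sklearn.model_selection.GridSearchCV and similar classes that accept a CV splitter (an object with split and get_n_splits methods) as their cv argument.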
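
    The UpsampleStratifiedKFold class hard-codes labels 0 and 1. Below is a minimal sketch of a multiclass generalization. It assumes that every class should be upsampled, with replacement, to the size of the largest class in each training fold, while test folds stay untouched. The name UpsampleAllClassesKFold is hypothetical, and the make_classification data are synthetic and for illustration only. The sketch also shows the splitter being passed to GridSearchCV through its cv argument.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold


class UpsampleAllClassesKFold:
    """Hypothetical multiclass variant of UpsampleStratifiedKFold.

    In each training fold every class is upsampled (with replacement) to the
    size of the largest class; the test fold is returned unchanged.
    """

    def __init__(self, n_splits=3, random_state=None):
        self.n_splits = n_splits
        self.random_state = random_state

    def split(self, X, y, groups=None):
        y = np.asarray(y)
        # Fresh generator per call, so an int random_state makes the
        # upsampled folds reproducible across calls to split().
        rng = np.random.default_rng(self.random_state)
        for train_idx, test_idx in StratifiedKFold(n_splits=self.n_splits).split(X, y):
            y_train = y[train_idx]
            classes, counts = np.unique(y_train, return_counts=True)
            target = counts.max()
            parts = []
            for cls, count in zip(classes, counts):
                # Positions (within train_idx) of this class's samples.
                pos = np.flatnonzero(y_train == cls)
                if count < target:
                    # Keep every original sample and top up with draws with replacement.
                    extra = rng.choice(pos, size=target - count, replace=True)
                    pos = np.concatenate([pos, extra])
                parts.append(pos)
            # Map fold-relative positions back to row indices of X / y.
            yield train_idx[np.concatenate(parts)], test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits


# Synthetic, imbalanced 3-class data (roughly 70% / 20% / 10%), for illustration only.
X, y = make_classification(
    n_samples=600, n_classes=3, n_informative=4, weights=[0.7, 0.2, 0.1], random_state=0
)

cv = UpsampleAllClassesKFold(n_splits=5, random_state=0)
for train_idx, test_idx in cv.split(X, y):
    # Training counts are equal across classes; test counts keep the original imbalance.
    print("train:", np.bincount(y[train_idx]), "test:", np.bincount(y[test_idx]))

search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    cv=cv,
    scoring="f1_macro",
)
search.fit(X, y)
print(search.best_params_)
```

    One design point to keep in mind: the splitter only affects the cross-validation folds. With the default refit=True, GridSearchCV refits best_estimator_ on the full, non-upsampled X and y.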
