Scikit-learn, GroupKFold with shuffling groups?

被刻印的时光 ゝ submitted on 2019-12-04 09:31:24

Question


I was using StratifiedKFold from scikit-learn, but now I also need to account for "groups". There is a nice class, GroupKFold, but my data are very time dependent. So, similarly to the example in the documentation, the week number is the grouping index, and each week should appear in only one fold.

Suppose I need 10 folds. What I need is to shuffle the data first, before I can use GroupKFold.

The shuffling is in the group sense: whole groups should be shuffled among each other.

Is there an elegant way to do this with scikit-learn? It seems to me that GroupKFold is unaffected by shuffling the data first.
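The observation about GroupKFold and row shuffling can be checked with a small sketch. It compares which group labels are held out together in each fold, before and after shuffling the rows. The helper name `val_group_sets` and the toy data are assumptions made for illustration only.

```python
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.utils import shuffle


def val_group_sets(X, y, groups, n_splits=3):
    """Return the set of group labels held out in each fold (hypothetical helper)."""
    groups = np.asarray(groups)
    splitter = GroupKFold(n_splits=n_splits)
    return {
        frozenset(groups[val_idx].tolist())
        for _, val_idx in splitter.split(X, y, groups)
    }


# Toy data: group 0 has three rows, every other group has one row.
X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])

before = val_group_sets(X, y, groups)

# Shuffle the rows of X, y and groups together, then split again.
X_s, y_s, groups_s = shuffle(X, y, groups, random_state=0)
after = val_group_sets(X_s, y_s, groups_s)

# Compare which groups are held out together, ignoring row positions.
print(before == after)
```

With the default GroupKFold settings, the row indices differ after shuffling, but the same groups should be held out together, because the assignment of groups to folds depends on the group sizes and not on the row order.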

If there is no way to do it with scikit-learn, can anyone write some efficient code for this? I have large data sets.

The inputs are a matrix, labels, and groups.


Answer 1:


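EDIT: This solution does not work. GroupKFold assigns groups to folds based on the group sizes alone, so shuffling the rows does not change which groups end up in the same fold.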

I think using sklearn.utils.shuffle is an elegant solution!

For data in X, y and groups:

from sklearn.utils import shuffle
X_shuffled, y_shuffled, groups_shuffled = shuffle(X, y, groups, random_state=0)

Then use X_shuffled, y_shuffled and groups_shuffled with GroupKFold:

from sklearn.model_selection import GroupKFold
group_k_fold = GroupKFold(n_splits=10)
splits = group_k_fold.split(X_shuffled, y_shuffled, groups_shuffled)

Of course, you probably want to shuffle multiple times and do the cross-validation with each shuffle. You could put the entire thing in a loop - here's a complete example with 5 shuffles (and only 3 splits instead of your required 10):

import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.utils import shuffle

X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]

n_shuffles = 5
group_k_fold = GroupKFold(n_splits=3)

for i in range(n_shuffles):
    # Shuffle the rows of X, y and groups together, with a different seed each time
    X_shuffled, y_shuffled, groups_shuffled = shuffle(X, y, groups, random_state=i)
    splits = group_k_fold.split(X_shuffled, y_shuffled, groups_shuffled)
    # do something with splits here, I'm just printing them out
    print('Shuffle', i)
    print('groups_shuffled:', groups_shuffled)
    for train_idx, val_idx in splits:
        print('Train:', train_idx)
        print('Val:', val_idx)
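One possible way to shuffle at the group level is to run a shuffled KFold over the unique group labels and then map those labels back to row indices. This is only a sketch under assumptions: the helper name `shuffled_group_kfold` is hypothetical and not part of scikit-learn, and the folds are balanced by number of groups rather than by number of samples, which differs from GroupKFold's balancing.

```python
import numpy as np
from sklearn.model_selection import KFold


def shuffled_group_kfold(groups, n_splits=10, random_state=None):
    """Yield (train_idx, val_idx) with whole groups shuffled across folds.

    Hypothetical helper, not a scikit-learn API.
    """
    groups = np.asarray(groups)
    unique_groups = np.unique(groups)
    # shuffle=True permutes the unique group labels, not the individual rows.
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for _, val_g in kfold.split(unique_groups):
        # Every row whose group was drawn for validation goes to the validation set.
        val_mask = np.isin(groups, unique_groups[val_g])
        yield np.flatnonzero(~val_mask), np.flatnonzero(val_mask)


# Same toy data as in the example above.
X = np.arange(20).reshape((10, 2))
y = np.arange(10)
groups = np.array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])

for seed in range(2):
    print('Seed', seed)
    for train_idx, val_idx in shuffled_group_kfold(groups, n_splits=3, random_state=seed):
        print('  Val groups:', np.unique(groups[val_idx]))
```

Each group stays in exactly one validation fold, and changing `random_state` changes which groups are held out together. Recent scikit-learn releases may also accept a `shuffle` option on GroupKFold directly, so it is worth checking the documentation for the installed version.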


Source: https://stackoverflow.com/questions/40819598/scikit-learn-groupkfold-with-shuffling-groups
