Sklearn StratifiedKFold: ValueError: Supported target types are: ('binary', 'multiclass'). Got 'multilabel-indicator' instead

攒了一身酷 2020-12-17 08:36

I'm working with sklearn's StratifiedKFold split, and when I attempt to split using a multi-class target, I receive an error (see below). When I split using a binary target, it works n…
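A minimal sketch of where the error comes from, assuming the multi-class target is one-hot encoded (the 'multilabel-indicator' in the title is what sklearn's `type_of_target` reports for a 2D 0/1 matrix). The data is hypothetical. StratifiedKFold only accepts 1D 'binary' or 'multiclass' targets, so one common fix for single-label one-hot data is to collapse each row back to an integer class with `argmax` and stratify on that.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.multiclass import type_of_target

# Hypothetical one-hot encoded multi-class target: 3 classes, 9 samples.
y_onehot = np.eye(3, dtype=int)[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
X = np.arange(18).reshape(9, 2)

print(type_of_target(y_onehot))  # multilabel-indicator

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
try:
    # The target-type check runs when the first fold is generated.
    next(skf.split(X, y_onehot))
    err_msg = ""
except ValueError as err:
    err_msg = str(err)
    print("ValueError:", err_msg)

# Collapse each one-hot row back to a single integer class label.
y_labels = y_onehot.argmax(axis=1)
print(type_of_target(y_labels))  # multiclass

for train_idx, test_idx in skf.split(X, y_labels):
    # Index the original one-hot matrix if the model expects one-hot targets.
    y_train, y_test = y_onehot[train_idx], y_onehot[test_idx]
    print("TEST:", test_idx, "classes:", y_labels[test_idx].tolist())
```

This only applies when each row has exactly one 1; genuinely multilabel rows cannot be collapsed this way.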

4 answers
  •  春和景丽
    2020-12-17 08:55

    In my case, X was a 2D matrix and y was also a 2D matrix, i.e. genuinely a multi-class multi-output case. I simply passed a dummy np.zeros(shape=(n, 1)) as y to .split(), and X as usual. Full code example:

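    import numpy as np
    from sklearn.model_selection import RepeatedStratifiedKFold

    X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 7], [9, 4]])
    # y = np.array([0, 0, 1, 1, 0, 1])  # <<< works: a 1D binary target
    y = X  # does not work if passed into `.split`: a 2D multi-output target
    rskf = RepeatedStratifiedKFold(n_splits=3, n_repeats=3, random_state=36851234)
    # Pass a dummy single-column y of zeros to .split; the real 2D y is only used
    # for indexing below. Every sample gets the same dummy label, so the folds are
    # effectively random and NOT stratified with respect to the real targets.
    for train_index, test_index in rskf.split(X, np.zeros(shape=(X.shape[0], 1))):
        print("TRAIN:", train_index, "TEST:", test_index)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]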
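Passing a dummy all-zeros y silences the error, but it also means no stratification happens. If y has several output columns and you still want folds balanced on the label combinations, one alternative (not from the answer above) is to stratify on a single "joint class" id per distinct row of y. A minimal sketch, assuming every distinct row of y occurs at least n_splits times (otherwise StratifiedKFold warns or raises); the data is hypothetical.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Hypothetical 2-column multi-output target: 4 distinct rows, 3 copies each.
y = np.array([[0, 0], [0, 1], [1, 0], [1, 1]] * 3)
X = np.arange(len(y) * 2).reshape(len(y), 2)

# Map each distinct row of y to one integer "joint class" id.
# reshape(-1) keeps the ids 1D regardless of NumPy version.
_, joint = np.unique(y, axis=0, return_inverse=True)
joint = joint.reshape(-1)

rskf = RepeatedStratifiedKFold(n_splits=3, n_repeats=2, random_state=0)
for train_index, test_index in rskf.split(X, joint):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]  # the real 2D targets
    print("TEST:", test_index, "joint ids:", sorted(joint[test_index].tolist()))
```

With only four combinations here, each test fold holds one copy of each distinct row. With many output columns the number of combinations grows quickly and many of them will have too few members, so this approach stops being practical.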
    
