Scikit-learn train_test_split with indices

忘了有多久 2020-12-04 14:07

How do I get the original indices of the data when using train_test_split()?

What I have is the following

from sklearn.cross_validation import train_         


        
5 Answers
  •  一生所求
    2020-12-04 14:41

    The docs mention that train_test_split is just a convenience function on top of ShuffleSplit.

    I rearranged some of their code to make my own example. Note that the actual solution is the middle block of code; the rest is imports and setup for a runnable example.

    from itertools import chain

    import numpy as np
    from sklearn.model_selection import ShuffleSplit
    from sklearn.utils import indexable

    # Setup: 10 samples with 2 features each, and 10 binary labels
    X = np.reshape(np.random.randn(20), (10, 2))
    y = np.random.randint(2, size=10)
    seed = 1

    # The actual solution: split once and keep the index arrays
    cv = ShuffleSplit(random_state=seed, test_size=0.25)
    arrays = indexable(X, y)
    train, test = next(cv.split(X=X))
    # safe_indexing was removed from sklearn.utils in scikit-learn 0.24;
    # plain NumPy indexing does the same job for the arrays used here.
    iterator = list(chain.from_iterable((
        a[train],
        a[test],
        train,
        test
        ) for a in arrays)
    )
    X_train, X_test, train_is, test_is, y_train, y_test, _, _ = iterator

    # Check the result: train_is holds the rows of X that make up X_train
    print(X)
    print(train_is)
    print(X_train)
    

    Now I have the actual indexes: train_is, test_is
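
A shorter route, sketched here as an alternative and not as the method used in the answer above: train_test_split accepts any number of arrays of equal length, so an array of row positions can be split alongside X and y. The names indices, idx_train and idx_test, as well as the toy data, are illustrative assumptions.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (assumed for illustration): 10 samples, 2 features, binary labels
X = np.arange(20).reshape(10, 2)
y = np.arange(10) % 2

# Original row positions, split together with X and y
indices = np.arange(len(X))

# Each input array yields a (train, test) pair, in the order passed in
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, y, indices, test_size=0.25, random_state=1
)

print(idx_train)  # positions in X of the training rows
print(idx_test)   # positions in X of the test rows
```

Because all three arrays are shuffled with the same permutation, X[idx_train] should reproduce X_train row for row, and likewise for the test split.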
