How to perform k-fold cross-validation with TensorFlow?

灰色年华 2021-01-31 04:39

I am following TensorFlow's Iris example.

In my case, all the data is in a single CSV file rather than split into separate training and test files, and I want to apply k-fold cross-validation to that data.

2 Answers
  •  灰色年华
    2021-01-31 05:34

    I know this question is old, but in case someone is looking to do something similar, here is an expansion of ahmedhosny's answer:

    The newer TensorFlow Dataset API (tf.data) can create dataset objects from Python generators, so, combined with scikit-learn's KFold, one option is to create a dataset from the KFold.split() generator:

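    import numpy as np
    
    from sklearn.datasets import load_iris
    from sklearn.model_selection import KFold
    
    # TensorFlow 1.x only: tf.contrib and tf.enable_eager_execution() were removed in 2.x
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    tf.enable_eager_execution()
    
    data = load_iris()
    X = data['data']    # float64 features, shape (150, 4)
    y = data['target']  # integer class labels, shape (150,)
    
    def make_dataset(X_data, y_data, n_splits):
        # Each element of the dataset is one fold: (X_train, y_train, X_test, y_test)
        def gen():
            # shuffle=True because load_iris() returns the samples sorted by class;
            # without it most test folds would contain only one class
            kfold = KFold(n_splits, shuffle=True, random_state=0)
            for train_index, test_index in kfold.split(X_data):
                X_train, X_test = X_data[train_index], X_data[test_index]
                y_train, y_test = y_data[train_index], y_data[test_index]
                yield X_train, y_train, X_test, y_test
    
        # Labels are integers, so they are declared as int64 rather than float64
        return tf.data.Dataset.from_generator(
            gen, (tf.float64, tf.int64, tf.float64, tf.int64))
    
    dataset = make_dataset(X, y, 10)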
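    The question's data lives in a single CSV file rather than coming from load_iris. A minimal sketch of turning such a file into the X and y arrays that make_dataset expects, assuming every column is numeric, the class label is an integer in the last column, and the file has one header row. The load_csv helper and the toy CSV contents are hypothetical, for illustration only.

```python
import io

import numpy as np


def load_csv(path_or_buffer, label_column=-1, skip_header=1):
    """Hypothetical helper: split one CSV file into features X and labels y.

    Assumes all columns are numeric and the class label is stored as an
    integer in `label_column` (the last column by default).
    """
    table = np.genfromtxt(path_or_buffer, delimiter=",", skip_header=skip_header)
    table = np.atleast_2d(table)                 # keep 2-D even for a one-row file
    y = table[:, label_column].astype(np.int64)  # integer class labels
    X = np.delete(table, label_column, axis=1)   # every other column is a feature
    return X, y


# Toy stand-in for the question's CSV file (made-up contents).
csv_text = """f1,f2,label
5.1,3.5,0
7.0,3.2,1
6.3,3.3,2
4.9,3.0,0
"""
X, y = load_csv(io.StringIO(csv_text))
print(X.shape, y)
```

    With a real file, X, y = load_csv("data.csv") (a hypothetical file name) followed by make_dataset(X, y, 10) would take the place of load_iris(). String-valued labels would first need to be mapped to integers.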
    

    The dataset can then be iterated over either in graph-based TensorFlow or with eager execution. With eager execution:

    for X_train, y_train, X_test, y_test in tfe.Iterator(dataset):
        # Train a fresh model on (X_train, y_train) and evaluate it on (X_test, y_test)
        pass
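    tf.contrib, tf.enable_eager_execution() and tfe.Iterator do not exist in TensorFlow 2.x, where eager execution is on by default. Below is a sketch of the same generator approach for TF 2.x, assuming TF 2.4 or later (where from_generator accepts output_signature), that also fills in the loop body: a fresh Keras model is trained on each fold's training part and scored on its test part, and the per-fold accuracies are averaged. build_model, the layer sizes, the epoch count and the batch size are hypothetical choices, not part of the original answer.

```python
import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold


def make_dataset(X_data, y_data, n_splits, seed=0):
    """Dataset yielding one (X_train, y_train, X_test, y_test) tuple per fold."""
    def gen():
        # shuffle=True matters for iris: load_iris() returns samples sorted by class.
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
        for train_index, test_index in kf.split(X_data):
            yield (X_data[train_index], y_data[train_index],
                   X_data[test_index], y_data[test_index])

    # Fold sizes can differ by one sample, so the first dimension is None.
    features = tf.TensorSpec(shape=(None, X_data.shape[1]), dtype=tf.float32)
    labels = tf.TensorSpec(shape=(None,), dtype=tf.int64)
    return tf.data.Dataset.from_generator(
        gen, output_signature=(features, labels, features, labels))


def build_model(n_features, n_classes):
    # Hypothetical small classifier; the layer sizes are arbitrary.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(n_features,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


data = load_iris()
X = data["data"].astype(np.float32)
y = data["target"].astype(np.int64)
dataset = make_dataset(X, y, n_splits=10)

scores = []
# Eager execution is the TF 2.x default, so the dataset is iterated directly.
for X_train, y_train, X_test, y_test in dataset:
    model = build_model(n_features=X.shape[1], n_classes=3)  # fresh weights per fold
    model.fit(X_train, y_train, epochs=30, batch_size=16, verbose=0)
    results = model.evaluate(X_test, y_test, verbose=0, return_dict=True)
    scores.append(results["accuracy"])

print(f"{len(scores)}-fold accuracy: {np.mean(scores):.3f} +/- {np.std(scores):.3f}")
```

    Building the model inside the loop is deliberate: reusing one model across folds would carry weights learned from an earlier fold, whose training part overlaps later folds' test parts, and inflate the averaged score.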
    
