Split a dataset created by the TensorFlow Dataset API into train and test sets?

天命终不由人 2020-12-08 02:12

Does anyone know how to split a dataset created by the Dataset API (tf.data.Dataset) in TensorFlow into training and test sets?

8 answers
  •  無奈伤痛
    2020-12-08 02:54

    If the size of the dataset is known:

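    from typing import Tuple
    import tensorflow as tf
    
    def split_dataset(dataset: tf.data.Dataset,
                      dataset_size: int,
                      train_ratio: float,
                      validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
        assert (train_ratio + validation_ratio) < 1
    
        train_count = int(dataset_size * train_ratio)
        validation_count = int(dataset_size * validation_ratio)
        test_count = dataset_size - (train_count + validation_count)
    
        # Shuffle once and keep that order fixed. With the default
        # reshuffle_each_iteration=True, every pass over the data (e.g. every
        # epoch) would use a new order, so take()/skip() would pick different
        # elements each time and the three splits would overlap.
        dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)
    
        # Consecutive, non-overlapping slices of the fixed shuffled order.
        train_dataset = dataset.take(train_count)
        validation_dataset = dataset.skip(train_count).take(validation_count)
        test_dataset = dataset.skip(train_count + validation_count).take(test_count)
    
        return train_dataset, validation_dataset, test_dataset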
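
    The function needs `dataset_size` as an argument. When that number is not at hand, it can often be read from the pipeline itself. The following is a minimal sketch, assuming TensorFlow 2.3 or later (where `Dataset.cardinality()` and the `tf.data.INFINITE_CARDINALITY` / `tf.data.UNKNOWN_CARDINALITY` constants exist); the helper name `get_dataset_size` is hypothetical and not part of TensorFlow.

```python
import tensorflow as tf


def get_dataset_size(dataset: tf.data.Dataset) -> int:
    """Return the number of elements in a finite tf.data.Dataset (hypothetical helper)."""
    # cardinality() is derived from the pipeline structure without reading
    # any data. It is exact for sources like from_tensor_slices, but becomes
    # UNKNOWN after transformations such as filter().
    size = int(dataset.cardinality().numpy())
    if size == tf.data.INFINITE_CARDINALITY:
        raise ValueError("cannot split an infinite dataset (e.g. one built with repeat())")
    if size == tf.data.UNKNOWN_CARDINALITY:
        # Fall back to counting, which iterates over the whole dataset once.
        size = int(
            dataset.reduce(tf.constant(0, dtype=tf.int64), lambda count, _: count + 1).numpy()
        )
    return size
```

    The counting fallback reads the data once, so for large file-backed datasets it is cheaper to record the size when the dataset is created.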
    

    Example:

    size_of_ds = 1001
    train_ratio = 0.6
    val_ratio = 0.2
    
    ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
    train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
    # int(1001 * 0.6) = 600 train, int(1001 * 0.2) = 200 validation, 1001 - 800 = 201 test
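
    If the size is not known and counting is too expensive, a common alternative is to split by position instead of by count: number the elements with `enumerate()` and send every k-th element to the test set. This is a sketch assuming TensorFlow 2.3 or later (for `Dataset.enumerate()`); `split_by_index` and `test_every` are hypothetical names. It does not shuffle, so if the source order is meaningful (e.g. sorted by label), shuffle first with `reshuffle_each_iteration=False` as in the function above.

```python
import tensorflow as tf


def split_by_index(dataset: tf.data.Dataset, test_every: int = 5):
    """Hypothetical helper: deterministic (train, test) split without knowing the size.

    Every `test_every`-th element (by position) goes to the test split and the
    rest go to the training split, e.g. test_every=5 gives roughly 80/20.
    """
    # enumerate() turns each element x into a pair (index, x), index being int64.
    indexed = dataset.enumerate()

    test = indexed.filter(lambda i, x: tf.equal(i % test_every, 0))
    train = indexed.filter(lambda i, x: tf.not_equal(i % test_every, 0))

    # Drop the index so both splits keep the original element structure.
    def drop_index(i, x):
        return x

    return train.map(drop_index), test.map(drop_index)
```

    Both splits read the underlying source independently, so the source must produce the same order on every pass for the split to stay disjoint.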
    
