Split a dataset created by the TensorFlow Dataset API into train and test?

天命终不由人 2020-12-08 02:12

Does anyone know how to split a dataset created by the Dataset API (tf.data.Dataset) in TensorFlow into test and train?

8 Answers
  • 2020-12-08 02:54

In case the size of the dataset is known:

    from typing import Tuple
    import tensorflow as tf
    
    def split_dataset(dataset: tf.data.Dataset,
                      dataset_size: int,
                      train_ratio: float,
                      validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
        assert (train_ratio + validation_ratio) < 1

        train_count = int(dataset_size * train_ratio)
        validation_count = int(dataset_size * validation_ratio)
        test_count = dataset_size - (train_count + validation_count)

        # Shuffle once with a fixed order; with the default
        # reshuffle_each_iteration=True, take()/skip() would see a different
        # shuffle on each pass and the splits could overlap.
        dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)

        train_dataset = dataset.take(train_count)
        validation_dataset = dataset.skip(train_count).take(validation_count)
        test_dataset = dataset.skip(validation_count + train_count).take(test_count)

        return train_dataset, validation_dataset, test_dataset
    

    Example:

    size_of_ds = 1001
    train_ratio = 0.6
    val_ratio = 0.2
    
    ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
    train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
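
    Since the shuffle order is pinned with reshuffle_each_iteration=False, the three splits stay disjoint across iterations. As a quick illustrative check (not part of the original answer; assumes TF 2.x eager execution):

    train_ids = set(train_ds.as_numpy_iterator())
    val_ids = set(val_ds.as_numpy_iterator())
    test_ids = set(test_ds.as_numpy_iterator())

    # No element should appear in more than one split, and together
    # the splits should cover the full dataset.
    assert train_ids.isdisjoint(val_ids)
    assert train_ids.isdisjoint(test_ids)
    assert val_ids.isdisjoint(test_ids)
    assert len(train_ids | val_ids | test_ids) == size_of_ds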
    
  • 2020-12-08 02:57

    Most of the answers here use take() and skip(), which requires knowing the size of your dataset beforehand. This isn't always possible, or is difficult/intensive to ascertain.
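
    (In TF 2.x you can ask for the size with tf.data.experimental.cardinality(), but for datasets read from files it often comes back unknown. A minimal sketch of that check:)

    n = tf.data.experimental.cardinality(dataset)
    # For file-based pipelines n is frequently
    # tf.data.experimental.UNKNOWN_CARDINALITY, in which case the
    # take()/skip() counts cannot be computed up front.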

    Instead, what you can do is essentially slice the dataset up so that 1 in every N records becomes a validation record.

    To accomplish this, let's start with a simple dataset of 0-9:

    dataset = tf.data.Dataset.range(10)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    

    Now for our example, we're going to slice it so that we have a 3/1 train/validation split: 3 records go to training, then 1 record to validation, then repeat.

    split = 3
    dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
    # [0, 1, 2, 4, 5, 6, 8, 9]
    dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
    # [3, 7]
    

    So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 elements before starting the next window, and repeat. That + 1 effectively skips the 1 element we're going to use in our validation dataset.
    The flat_map(lambda ds: ds) is there because window() returns each window as a small nested dataset rather than as plain elements, which we don't want, so we flatten it back out into a single stream.

    Then for the validation data we first skip(split), which skips over the first split (3) elements that were grabbed by the first training window, so our iteration starts on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4), and repeats.
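
    You can confirm the resulting split by materializing both datasets (an illustrative check, assuming TF 2.x eager execution):

    print(list(dataset_train.as_numpy_iterator()))       # [0, 1, 2, 4, 5, 6, 8, 9]
    print(list(dataset_validation.as_numpy_iterator()))  # [3, 7]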

     

    Note on nested datasets:
    The above example works well for simple datasets, but flat_map() will generate an error if the dataset is nested. To address this, you can swap out the flat_map() with a more complicated version that can handle both simple and nested datasets:

    .flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
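
    For example, here is a minimal sketch of that variant on a nested (feature, label) dataset; the names features and labels are illustrative, not from the original answer:

    features = tf.data.Dataset.range(10)
    labels = features.map(lambda x: x * 10)
    nested = tf.data.Dataset.zip((features, labels))

    split = 3
    nested_train = nested.window(split, split + 1).flat_map(
        lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
    # [(0, 0), (1, 10), (2, 20), (4, 40), (5, 50), (6, 60), (8, 80), (9, 90)]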
    