Does anyone know how to split a dataset created by the dataset API (tf.data.Dataset) in Tensorflow into Test and Train?
In case the size of the dataset is known:
from typing import Tuple
import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset,
                  dataset_size: int,
                  train_ratio: float,
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    # reshuffle_each_iteration=False keeps the shuffled order fixed, so the
    # three splits stay disjoint across epochs; with the default (True), the
    # take/skip splits would draw from a different shuffle on each iteration.
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset
Example:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
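As a quick sanity check, you can count the elements in each split by iterating them (fine for a toy dataset like this, though it consumes a full pass):

print(len(list(train_ds)))  # 600 -> int(1001 * 0.6)
print(len(list(val_ds)))    # 200 -> int(1001 * 0.2)
print(len(list(test_ds)))   # 201 -> the remainder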
Most of the answers here use take() and skip(), which require knowing the size of your dataset beforehand. This isn't always possible, or can be difficult or expensive to ascertain.
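For example (a minimal sketch), a dataset built from a Python generator reports an unknown cardinality, so there is no size to compute split counts from:

import tensorflow as tf

# Generator-backed datasets don't know their own length up front.
gen_ds = tf.data.Dataset.from_generator(
    lambda: range(10),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int64))
print(gen_ds.cardinality() == tf.data.UNKNOWN_CARDINALITY)  # tf.Tensor(True, ...)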
Instead, what you can do is essentially slice the dataset up so that 1 out of every N records becomes a validation record.
To accomplish this, let's start with a simple dataset of 0-9:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Now for our example, we're going to slice it so that we have a 3/1 train/validation split: 3 records go to training, then 1 record goes to validation, and the pattern repeats.
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
So the first dataset.window(split, split + 1) says to grab split number (3) of elements, then advance split + 1 elements before starting the next window, and repeat. That + 1 effectively skips over the 1 element we're going to use in our validation dataset.

The flat_map(lambda ds: ds) is there because window() returns each window as its own small dataset (a dataset of datasets) rather than as flat elements, which we don't want, so we flatten it back out.
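To see that nesting concretely, iterate the windows before flattening; each window is itself a small dataset:

for window in dataset.window(split, split + 1):
    print(list(window.as_numpy_iterator()))
# [0, 1, 2]
# [4, 5, 6]
# [8, 9]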
Then for the validation data we first skip(split), which skips over the first split number (3) of elements that were grabbed by the first training window, so our iteration starts on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4) elements, and repeats.
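Putting the two together, here's a small helper that returns both splits (a sketch; the name split_every_n is just illustrative):

def split_every_n(dataset: tf.data.Dataset, split: int = 3):
    # Out of every (split + 1) records, `split` go to training and 1 to validation.
    train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
    val = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
    return train, val

train_ds, val_ds = split_every_n(tf.data.Dataset.range(10), split=3)
# train_ds -> [0, 1, 2, 4, 5, 6, 8, 9], val_ds -> [3, 7]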
Note on nested datasets:
The above example works well for simple datasets, but flat_map() will raise an error if the dataset is nested (for example, a zipped dataset of (features, labels) tuples). To address this, you can swap the flat_map() out for a more complicated version that can handle both simple and nested datasets:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
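For instance (a sketch with made-up numeric features and labels), the nested-safe version works on a zipped (features, labels) dataset, where the plain flat_map(lambda ds: ds) would fail because each windowed element is a tuple of datasets:

features = tf.data.Dataset.range(10)
labels = tf.data.Dataset.range(10).map(lambda x: x * 10)
nested = tf.data.Dataset.zip((features, labels))  # elements are (feature, label) pairs

split = 3
train = nested.window(split, split + 1).flat_map(
    lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
# [(0, 0), (1, 10), (2, 20), (4, 40), (5, 50), (6, 60), (8, 80), (9, 90)]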