Question
Does anyone know how to split a dataset created with the Dataset API (tf.data.Dataset) in TensorFlow into training and test sets?
Answer 1:
Assuming you have an all_dataset variable of type tf.data.Dataset:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
The test dataset now holds the first 1000 elements, and the rest go to training.
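To make the behaviour concrete, here is a minimal sketch of the same take/skip split on a toy dataset. The use of Dataset.range(10) as a stand-in for all_dataset and the split size of 3 are assumptions chosen only for illustration.

```python
import tensorflow as tf

# Toy stand-in for all_dataset: the integers 0..9 (an assumption for illustration).
all_dataset = tf.data.Dataset.range(10)

test_dataset = all_dataset.take(3)   # the first 3 elements
train_dataset = all_dataset.skip(3)  # everything after the first 3

# Materialise both splits as plain Python lists so they can be inspected.
test_items = [int(x) for x in test_dataset.as_numpy_iterator()]
train_items = [int(x) for x in train_dataset.as_numpy_iterator()]

print(test_items)   # [0, 1, 2]
print(train_items)  # [3, 4, 5, 6, 7, 8, 9]
```

Because the dataset is not shuffled here, the split simply follows the order in which elements are stored.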
Answer 2:
You may use Dataset.take() and Dataset.skip():
import tensorflow as tf

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer_size; DATASET_SIZE shuffles across the whole dataset.
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
For generality, this example uses a 70/15/15 train/val/test split, but if you don't need both a test and a val set, just ignore the last two lines.
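The three-way split can also be sketched with an explicitly named intermediate dataset, which makes each split's size easy to check. This is only an illustration: Dataset.range stands in for the TFRecord file, DATASET_SIZE = 103 is a made-up value, the name remaining is hypothetical, and shuffling is left out so the result is deterministic.

```python
import tensorflow as tf

DATASET_SIZE = 103  # hypothetical size, chosen so the fractions do not divide evenly
full_dataset = tf.data.Dataset.range(DATASET_SIZE)  # stand-in for the TFRecord data

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
# Let the test split absorb whatever is left after rounding down.
test_size = DATASET_SIZE - train_size - val_size

train_dataset = full_dataset.take(train_size)
remaining = full_dataset.skip(train_size)  # everything that is not training data
val_dataset = remaining.take(val_size)
test_dataset = remaining.skip(val_size)


def to_list(ds):
    """Materialise a small dataset of scalars as a Python list."""
    return [int(x) for x in ds.as_numpy_iterator()]


train_items = to_list(train_dataset)
val_items = to_list(val_dataset)
test_items = to_list(test_dataset)
print(len(train_items), len(val_items), len(test_items))
```

With these sizes every element lands in exactly one of the three splits.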
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
You may also want to look into Dataset.shard():
Creates a Dataset that includes only 1/num_shards of this dataset.
Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.
Answer 3:
At the moment, TensorFlow doesn't include a dedicated tool for this.
You could use sklearn.model_selection.train_test_split to generate the train/eval/test arrays, then create a tf.data.Dataset from each of them.
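A minimal sketch of that approach, assuming the data fits in memory as NumPy arrays. The features and labels arrays below are made-up toy data, and test_size=3 and random_state=0 are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Toy in-memory data (an assumption): 10 examples with 2 features each.
features = np.arange(20, dtype=np.float32).reshape(10, 2)
labels = np.arange(10, dtype=np.int64)

# Hold out 3 examples for testing; random_state makes the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    features, labels, test_size=3, random_state=0
)

# Wrap each split in its own tf.data.Dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

n_train = len(list(train_dataset))
n_test = len(list(test_dataset))
print(n_train, n_test)  # 7 3
```

Since the split happens before the datasets are built, shuffling train_dataset later cannot leak test examples into training.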
Answer 4:
@apatsekin, @ted: I don't have the 50 reputation needed to comment yet, so I'm replying here. I wonder whether it is reasonable to use .take() directly to get the test dataset. If the dataset is reshuffled every epoch, won't each epoch produce a different train/test split? During training, test examples must never appear in the training set, so this looks like a problem.
Or we could add a parameter to shuffle:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# buffer_size is required; reshuffle_each_iteration=False keeps one fixed order.
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
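One way to check this concern is to iterate both splits over several epochs and compare their contents. The sketch below is an illustration under assumptions: Dataset.range stands in for the TFRecord data, DATASET_SIZE = 20 is made up, and an explicit seed is added so the order is reproducible.

```python
import tensorflow as tf

DATASET_SIZE = 20  # hypothetical size
train_size = int(0.7 * DATASET_SIZE)

full_dataset = tf.data.Dataset.range(DATASET_SIZE)  # stand-in for the real data
# Shuffle once with a fixed seed and keep that order for every iteration.
full_dataset = full_dataset.shuffle(
    buffer_size=DATASET_SIZE, seed=42, reshuffle_each_iteration=False
)

train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)


def to_set(ds):
    """Collect the elements of a small scalar dataset into a Python set."""
    return {int(x) for x in ds.as_numpy_iterator()}


# Iterate the splits three times, as three training epochs would.
train_per_epoch = [to_set(train_dataset) for _ in range(3)]
test_per_epoch = [to_set(test_dataset) for _ in range(3)]

stable = all(t == train_per_epoch[0] for t in train_per_epoch)
disjoint = all(tr.isdisjoint(te) for tr, te in zip(train_per_epoch, test_per_epoch))
print(stable, disjoint)
```

If per-epoch shuffling of the training data is still wanted, a second shuffle() can be applied to train_dataset alone after the split.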
Answer 5:
You can use shard:
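# Optional. shuffle() needs a buffer_size; keep the order fixed so the shards do not overlap.
dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=False)
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)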
See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
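As a rough illustration, shard(num_shards, index) keeps every num_shards-th element starting at index, so two shards give a 50/50 split. The toy Dataset.range(10) input is an assumption. The uneven 80/20 variant at the end is not part of the answer above: it is one possible workaround that pairs shard() with enumerate() and filter(), and the names holdout and rest are hypothetical.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)  # toy stand-in for the real dataset

# 50/50 split: even positions go to training, odd positions to testing.
trainset = dataset.shard(num_shards=2, index=0)
testset = dataset.shard(num_shards=2, index=1)

train_items = [int(x) for x in trainset.as_numpy_iterator()]
test_items = [int(x) for x in testset.as_numpy_iterator()]
print(train_items)  # [0, 2, 4, 6, 8]
print(test_items)   # [1, 3, 5, 7, 9]

# Assumed extension for an 80/20 split: hold out every 5th element with shard(),
# and keep all other positions with enumerate() + filter().
holdout = dataset.shard(num_shards=5, index=0)
rest = (
    dataset.enumerate()
    .filter(lambda i, x: tf.not_equal(i % 5, 0))  # drop positions 0, 5, 10, ...
    .map(lambda i, x: x)                          # discard the position index
)

holdout_items = [int(x) for x in holdout.as_numpy_iterator()]
rest_items = [int(x) for x in rest.as_numpy_iterator()]
print(holdout_items)  # [0, 5]
print(rest_items)     # [1, 2, 3, 4, 6, 7, 8, 9]
```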
Source: https://stackoverflow.com/questions/48213766/split-a-dataset-created-by-tensorflow-dataset-api-in-to-train-and-test