Split a dataset created by the TensorFlow Dataset API into Train and Test?

天命终不由人 2020-12-08 02:12

Does anyone know how to split a dataset created by the Dataset API (tf.data.Dataset) in TensorFlow into train and test sets?

8 answers
  •  轮回少年
    2020-12-08 02:37

    You may use Dataset.take() and Dataset.skip():

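    import tensorflow as tf

    # DATASET_SIZE (total number of records) and FLAGS.input_file come from your own setup.
    train_size = int(0.7 * DATASET_SIZE)
    val_size = int(0.15 * DATASET_SIZE)
    test_size = int(0.15 * DATASET_SIZE)

    full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
    # shuffle() requires a buffer size. Disable per-iteration reshuffling,
    # otherwise the splits can overlap from one epoch to the next.
    full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE, reshuffle_each_iteration=False)
    train_dataset = full_dataset.take(train_size)
    test_dataset = full_dataset.skip(train_size)   # everything after the training split
    val_dataset = test_dataset.skip(test_size)     # what remains after the test split
    test_dataset = test_dataset.take(test_size)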
    

    For more generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need separate test and val sets, just ignore the last 2 lines of the code; test_dataset then holds everything outside the training split.

    Take:

    Creates a Dataset with at most count elements from this dataset.

    Skip:

    Creates a Dataset that skips count elements from this dataset.
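
    To make the take/skip semantics concrete, here is a minimal sketch on a toy dataset. The range dataset and the 7/2/1 sizes are assumptions chosen for illustration, not part of the original answer; no shuffling is applied, so the output is deterministic.

```python
import tensorflow as tf

# Toy dataset of the integers 0..9, standing in for real records (assumption).
dataset = tf.data.Dataset.range(10)

train = dataset.take(7)   # at most the first 7 elements
rest = dataset.skip(7)    # everything after the first 7 elements
val = rest.take(2)        # first 2 of the remainder
test = rest.skip(2)       # whatever is left after those 2

train_items = [int(x) for x in train.as_numpy_iterator()]
val_items = [int(x) for x in val.as_numpy_iterator()]
test_items = [int(x) for x in test.as_numpy_iterator()]

print(train_items)  # [0, 1, 2, 3, 4, 5, 6]
print(val_items)    # [7, 8]
print(test_items)   # [9]
```

    The three splits are disjoint and together cover the whole dataset, which is the property you want from a train/val/test split.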

    You may also want to look into Dataset.shard():

    Creates a Dataset that includes only 1/num_shards of this dataset.
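
    As a rough sketch of how shard() could be used for splitting: shard(num_shards, index) keeps every num_shards-th element starting at position index, so one shard can serve as the test set. Building the complementary training set with enumerate() and filter() is my own assumption about how to complete the split; the toy range dataset and the 80/20 ratio are likewise illustrative.

```python
import tensorflow as tf

# Toy dataset of the integers 0..9, standing in for real records (assumption).
dataset = tf.data.Dataset.range(10)
NUM_SHARDS = 5  # one shard out of five gives a 20% test split

# Test set: elements at positions 0, 5, 10, ...
test = dataset.shard(num_shards=NUM_SHARDS, index=0)

# Training set: every element whose position is NOT in shard 0.
train = (
    dataset.enumerate()
    .filter(lambda i, x: tf.not_equal(i % NUM_SHARDS, 0))
    .map(lambda i, x: x)  # drop the position index again
)

test_items = [int(x) for x in test.as_numpy_iterator()]
train_items = [int(x) for x in train.as_numpy_iterator()]

print(test_items)   # [0, 5]
print(train_items)  # [1, 2, 3, 4, 6, 7, 8, 9]
```

    Unlike take/skip, this approach does not need the dataset size up front, but it interleaves the splits rather than cutting the data into contiguous blocks.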


    Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.
