Epoch counter with TensorFlow Dataset API

悲&欢浪女 2020-12-19 01:27

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable

4 answers
  •  遥遥无期
    2020-12-19 02:16

    To add to @mrry's great answer: if you want to stay within the tf.data pipeline and also track the iteration (step) within each epoch, you can try my solution below. For a batch size other than 1, I guess you would have to add the line data = data.batch(bs).

    import itertools

    import tensorflow as tf  # TF 1.x API (tf.Session, make_one_shot_iterator)

    def step_counter():
        # Yields 0, 1, 2, ...; starts over whenever the dataset is re-iterated.
        for i in itertools.count():
            yield i

    num_examples = 3
    num_epochs = 2
    num_iters = num_examples * num_epochs

    # Toy (feature, label) pairs, shuffled.
    features = tf.data.Dataset.range(num_examples)
    labels = tf.data.Dataset.range(num_examples)
    data = tf.data.Dataset.zip((features, labels))
    data = data.shuffle(num_examples)

    # Attach the within-epoch step to every example.
    step = tf.data.Dataset.from_generator(step_counter, tf.int32)
    data = tf.data.Dataset.zip((data, step))

    # One pass over `data` per epoch index; the index is zipped onto each element.
    # flat_map already yields num_epochs passes, so no extra repeat() is needed.
    epoch = tf.data.Dataset.range(num_epochs)
    data = epoch.flat_map(
        lambda i: tf.data.Dataset.zip(
            (data, tf.data.Dataset.from_tensors(i).repeat())))

    it = data.make_one_shot_iterator()
    example = it.get_next()

    with tf.Session() as sess:
        for _ in range(num_iters):
            ((x, y), st), ep = sess.run(example)
            print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')
    

    Prints (the order of x and y within an epoch depends on the shuffle):

    step 0   epoch 0     x 2     y 2
    step 1   epoch 0     x 0     y 0
    step 2   epoch 0     x 1     y 1
    step 0   epoch 1     x 2     y 2
    step 1   epoch 1     x 0     y 0
    step 2   epoch 1     x 1     y 1
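
If it is unclear why the step counter goes back to 0 at the start of every epoch, a plain-Python model of the same construction may help. This is only a sketch of the semantics, not TensorFlow code: `tagged_epochs` is a hypothetical helper name, and the example order is hard-coded instead of shuffled. The assumption it illustrates is that the inner dataset is iterated afresh for each epoch index, so the generator is started again each time.

```python
import itertools


def step_counter():
    # Same generator as in the answer: 0, 1, 2, ...
    for i in itertools.count():
        yield i


def tagged_epochs(examples, num_epochs):
    """Hypothetical helper: yields ((x, y), step, epoch) tuples."""
    for epoch in range(num_epochs):
        # A fresh step_counter() per epoch models the inner dataset
        # being iterated again for each epoch index.
        for (x, y), step in zip(examples, step_counter()):
            yield (x, y), step, epoch


# Fixed order standing in for one shuffled pass (an assumption for this sketch).
examples = [(2, 2), (0, 0), (1, 1)]
rows = list(tagged_epochs(examples, num_epochs=2))

for (x, y), st, ep in rows:
    print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')
```

The step column runs 0, 1, 2 within each epoch and the epoch column changes only after a full pass, which is the same shape as the tf.data output above.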
    
