Epoch counter with TensorFlow Dataset API

悲&欢浪女 2020-12-19 01:27

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable

4 answers
  •  难免孤独
    2020-12-19 02:21

    I extended numerica's example code to work with batches and replaced the itertools part:

    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

    num_examples = 5
    num_epochs = 4
    batch_size = 2
    num_iters = num_examples * num_epochs // batch_size

    features = tf.data.Dataset.range(num_examples)
    labels = tf.data.Dataset.range(num_examples)

    shuffled = tf.data.Dataset.zip((features, labels))
    shuffled = shuffled.shuffle(num_examples)

    # For every epoch index i, emit the shuffled data once, pairing each
    # example with the epoch index and its position (step) within the epoch.
    epoch = tf.data.Dataset.range(num_epochs)
    data = epoch.flat_map(
        lambda i: tf.data.Dataset.zip((
            shuffled,
            tf.data.Dataset.from_tensors(i).repeat(),
            tf.data.Dataset.range(num_examples)
        ))
    )

    # Flatten the nested structure ((x, y), epoch, step) into (x, y, epoch, step)
    data = data.map(lambda samples, *cnts: samples + cnts)
    data = data.batch(batch_size=batch_size)

    it = data.make_one_shot_iterator()
    x, y, ep, st = it.get_next()

    with tf.Session() as sess:
        for _ in range(num_iters):
            x_, y_, ep_, st_ = sess.run([x, y, ep, st])
            print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')
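
To see what kind of stream the pipeline above emits without running TensorFlow, here is a minimal plain-Python sketch of the same tagging idea. It is only an analogue, not the tf.data implementation: `tagged_epochs` and `batch_stream` are hypothetical helper names, labels are dropped, and shuffling is left out so the result is deterministic.

```python
from itertools import islice


def tagged_epochs(examples, num_epochs):
    """Yield (example, epoch, step) for every example in every epoch."""
    for epoch in range(num_epochs):
        for step, example in enumerate(examples):
            yield example, epoch, step


def batch_stream(stream, batch_size):
    """Group a stream into lists of at most batch_size items."""
    stream = iter(stream)
    while True:
        batch = list(islice(stream, batch_size))
        if not batch:
            return
        yield batch


# Same sizes as in the answer: 5 examples, 4 epochs, batches of 2.
examples = list(range(5))
batches = list(batch_stream(tagged_epochs(examples, num_epochs=4), batch_size=2))

for batch in batches:
    xs, epochs, steps = zip(*batch)
    print(f"step {steps}\t epoch {epochs}\t x {xs}")
```

Because batching happens after the epochs are flattened into one stream, a batch can straddle an epoch boundary. With these sizes the third batch holds the last example of epoch 0 and the first example of epoch 1, so the epoch counter inside a batch is per example rather than a single value per batch.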
    
