I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable.
I extended numerica's example code to batches and replaced the itertools part:
import tensorflow as tf

num_examples = 5
num_epochs = 4
batch_size = 2
# Total number of batches over all epochs (assumes batch_size divides evenly).
num_iters = num_examples * num_epochs // batch_size

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
# reshuffle_each_iteration defaults to True, so every pass through `data`
# (one per epoch, via flat_map below) gets a fresh shuffle order.
data = data.shuffle(num_examples)

# For each epoch index i, re-iterate the shuffled dataset and tag every
# sample with the epoch number and a per-epoch step counter.
epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip((
        data,                                      # (x, y) sample
        tf.data.Dataset.from_tensors(i).repeat(),  # epoch index, repeated per sample
        tf.data.Dataset.range(num_examples),       # step within this epoch
    ))
)
# Flatten the nested element ((x, y), epoch, step) -> (x, y, epoch, step).
data = data.map(lambda samples, *cnts: samples + cnts)
data = data.batch(batch_size=batch_size)

# TF 1.x graph-mode iteration; use tf.compat.v1 equivalents under TF 2.x.
it = data.make_one_shot_iterator()
x, y, ep, st = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        x_, y_, ep_, st_ = sess.run([x, y, ep, st])
        print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')