I'm migrating my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable.
TL;DR: Replace the definition of epoch_counter with the following:
epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
trainable=False, use_resource=True)
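As a minimal sketch of how the resource variable behaves (assuming TF 1.x-style graph mode via tf.compat.v1; the names increment_op and final_count are mine, not from the original code), the variable can be read inside a Dataset.map() function and incremented from the outer training loop:

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# A resource variable, as required inside tf.data transformations.
epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)
increment_op = tf.assign_add(epoch_counter, 1.0)

# Reading the resource variable inside map() works; an old-style
# reference variable could not be captured here.
dataset = tf.data.Dataset.range(3).map(
    lambda x: (x, epoch_counter.read_value()))
iterator = tf.data.make_initializable_iterator(dataset)
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2):          # two "epochs"
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(next_element)
            except tf.errors.OutOfRangeError:
                break
        sess.run(increment_op)  # bump the counter at the end of each epoch
    final_count = sess.run(epoch_counter)
    print(final_count)
```

Each element comes back paired with the counter's current value, and after two passes the counter reads 2.0.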
There are some limitations around using TensorFlow variables inside tf.data.Dataset
transformations. The principal limitation is that all variables must be "resource variables" rather than the older "reference variables"; unfortunately, tf.Variable still creates reference variables for backwards-compatibility reasons.
Generally speaking, I wouldn't recommend using variables in a tf.data
pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range()
to define an epoch counter, and then do something like:
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))
The above snippet attaches the current epoch number to every element as a second component.
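Here is a hedged, self-contained version of that pattern (assuming TF 2.x eager execution; raw_data and the pass-through pre_processing_func are placeholders I've invented, since the original data and pre-processing are not shown):

```python
import tensorflow as tf

NUM_EPOCHS = 2
raw_data = tf.data.Dataset.from_tensor_slices([10, 20, 30])

def pre_processing_func(ds):
    # Placeholder for the real pre-processing: pass elements through unchanged.
    return ds

# Each epoch index i is zipped, repeated, against one full pass over the data.
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(raw_data), tf.data.Dataset.from_tensors(i).repeat())))

# Every element arrives as a (value, epoch) pair.
pairs = [(int(v), int(e)) for v, e in dataset]
print(pairs)
```

Iterating the dataset yields each value tagged with its epoch: all three elements with epoch 0, then all three again with epoch 1.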