Epoch counter with TensorFlow Dataset API

悲&欢浪女 2020-12-19 01:27

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable. How can I keep an epoch counter with the Dataset API?

4 Answers
  •  南方客 (OP)
     2020-12-19 02:23

    TL;DR: Replace the definition of epoch_counter with the following:

    epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                    trainable=False, use_resource=True)
    

    There are some limitations around using TensorFlow variables inside tf.data.Dataset transformations. The principal limitation is that all variables must be "resource variables" rather than the older "reference variables"; unfortunately, tf.Variable still creates "reference variables" for backwards-compatibility reasons.

    Generally speaking, I wouldn't recommend using variables in a tf.data pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range() to define an epoch counter, and then do something like:

    epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
    dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
        (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))
    

    The above snippet attaches an epoch counter to every value as a second component.
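    To see the shape of the output this pattern produces, here is a plain-Python analogue of the `flat_map`/`zip` idiom (an illustrative sketch only, not TensorFlow; `NUM_EPOCHS` and the per-epoch `data` list are stand-ins for the real dataset):

```python
from itertools import repeat

NUM_EPOCHS = 3
data = ["a", "b"]  # stand-in for the pre-processed per-epoch elements

def epoch_tagged(data, num_epochs):
    # Mirrors Dataset.range(num_epochs).flat_map(...): for each epoch
    # index i, pair every element of the epoch's data with a repeated
    # copy of i, so the epoch counter rides along as a second component.
    for i in range(num_epochs):
        yield from zip(data, repeat(i))

print(list(epoch_tagged(data, NUM_EPOCHS)))
# [('a', 0), ('b', 0), ('a', 1), ('b', 1), ('a', 2), ('b', 2)]
```

    Each value emitted by the pipeline is then a (element, epoch_index) pair, so downstream code can read the epoch without any mutable variable.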
