How to get the current global_step in a data pipeline

I am trying to create a filter that depends on the current global_step of training, but I can't get it to work properly.

First, I cannot use tf.train.get_or_create_global_step() in the code below because it throws:

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

That is why I fetched the current name scope with tf.get_default_graph().get_name_scope(). Inside that scope I was able to "get" the global step:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens

dataset = dataset.filter(filter_examples)
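To make the intended curriculum explicit, here is a plain-Python mirror of the threshold arithmetic in filter_examples. It is only a sketch: the helper names and the schedule values (min_subtokens, curriculum_step_size, curriculum_step_update) are hypothetical, since the question does not state them, and it does not solve reading the step inside tf.data.

```python
import math


def max_subtokens_at(step, min_subtokens, curriculum_step_size, curriculum_step_update):
    # Same arithmetic as the TF filter: floor(step / update) curriculum stages,
    # each stage raising the length limit by curriculum_step_size.
    stage = int(math.floor(step / curriculum_step_update))
    return min_subtokens + curriculum_step_size * stage


def keep_example(targets, step, **schedule):
    # Mirrors `tf.size(example['targets']) <= max_subtokens`.
    return len(targets) <= max_subtokens_at(step, **schedule)


# Hypothetical schedule values; the question does not give them.
schedule = dict(min_subtokens=8, curriculum_step_size=4, curriculum_step_update=1000)

for step in (0, 999, 1000, 2500):
    print(step, max_subtokens_at(step, **schedule))
```

With these values the limit is 8 up to step 999, 12 from step 1000, and 16 at step 2500. If the step the filter sees is stuck at 0, the limit stays at min_subtokens for the whole run.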

The problem is that this does not work as I expected. From what I observe, current_step in the code above is always 0 (I can't confirm this; it is an assumption based on what I see).

The only thing that seems to make a difference, odd as it sounds, is restarting the training. In that case, again based on observation, current_step appears to hold the actual training step at the time of the restart. However, the value does not update as training continues.

Is there a way to get the actual value of the current step and use it in my filter as above?


TensorFlow 1.12.1


As we discussed in the comments, maintaining and updating your own counter may be an alternative to using the global_step variable. The counter variable could be updated as follows:

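counter = tf.Variable(0, dtype=tf.int64, trainable=False, name='counter')
op = tf.assign_add(counter, 1)
with tf.control_dependencies([op]):  # takes a list of ops
    # Ops created here run only after the counter has been updated,
    # e.g. your training op
    train_op = tf.no_op()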

tf.control_dependencies lets you "attach" the counter update to a path in the computational graph: any op created inside the with block runs only after the update has executed. You can then use the counter variable wherever you need it.


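If you use variables inside datasets in TF 1.x, you need to re-initialize the iterator (here, once per epoch). On older 1.x releases such as 1.12, dataset.make_initializable_iterator() and tf.Session() can be used instead of the tf.compat.v1 calls.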

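iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for epoch in range(num_epochs):
        # Re-run the initializer so the dataset uses the current variable values
        sess.run(init)
        for _ in range(num_examples):
            tensor_vals = sess.run(tensors)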
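As a rough mental model consistent with this advice (and with the question's observation that the value only changed after restarting training), the plain-Python analogy below treats the filter threshold as being read once, when an iterator is created. It is not TensorFlow code; SnapshotFilteredDataset, current_threshold and the step dict are hypothetical names used only to show why re-creating or re-initializing the iterator picks up the updated value.

```python
class SnapshotFilteredDataset:
    """Plain-Python analogy (not TensorFlow): the length threshold is read once,
    when an iterator is created, and stays fixed for that iterator's lifetime."""

    def __init__(self, examples, get_threshold):
        self.examples = examples
        self.get_threshold = get_threshold

    def make_iterator(self):
        threshold = self.get_threshold()  # snapshot taken at "initialization"
        return (ex for ex in self.examples if len(ex) <= threshold)


step = {"value": 0}  # hypothetical stand-in for the training step counter


def current_threshold():
    return 2 + step["value"] // 10


examples = [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
dataset = SnapshotFilteredDataset(examples, current_threshold)

stale = dataset.make_iterator()  # created while step == 0, threshold 2
step["value"] = 20               # training advances
print(len(list(stale)))          # still uses threshold 2, so 2 examples

fresh = dataset.make_iterator()  # "re-initialized", threshold 4
print(len(list(fresh)))          # 4 examples
```

The stale iterator keeps filtering with the old limit even though the counter has moved on; only a new iterator sees the updated value.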

