[Edit #1 after @mrry comment] I am using the (great & amazing) Dataset API along with tf.contrib.data.rejection_resample to set a specific distribution
Following @mrry's response I was able to come up with a solution for how to use the Dataset API with tf.contrib.data.rejection_resample (using TF 1.3).
The goal
Given a feature/label dataset with some distribution, have the input pipeline reshape that distribution into a specific target distribution.
Numerical example
Let's assume we are building a network to classify some feature into one of 10 classes.
And assume we only have 100 features with some random distribution of labels: 30 features labeled as class 1, 5 features labeled as class 2, and so forth.
During training we do not want to prefer class 1 over class 2, so we would like each mini-batch to hold a uniform distribution across all classes. A concrete sketch of this setup follows below.
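Here is a minimal sketch of such an imbalanced input (assuming TF 1.3, where the Dataset API lives under tf.contrib.data; the toy feature array and class proportions are made up for illustration):

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 100 examples with an imbalanced label distribution
# (classes 1 and 9 are over-represented).
labels = np.random.choice(10, size=100,
                          p=[0.05, 0.30, 0.05, 0.05, 0.05,
                             0.05, 0.05, 0.05, 0.05, 0.30]).astype(np.int32)
features = np.random.randn(100, 4).astype(np.float32)

input_dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))

# Target: a uniform distribution over the 10 classes.
target_distribution = np.full([10], 0.1, dtype=np.float32)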
The solution
Using tf.contrib.data.rejection_resample allows us to set a specific distribution for our input pipeline.
According to the documentation, tf.contrib.data.rejection_resample takes:
(1) Dataset - the dataset you want to balance
(2) class_func - a function that generates a new numerical labels dataset only from the original dataset (see the sketch after this list)
(3) target_dist - a vector the size of the number of classes, specifying the required new distribution
(4) some more optional values - skipped for now
and, as the documentation says, it returns a `Dataset`.
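For (2), a minimal class_func sketch: it assumes dataset elements are (feature, label) pairs where the label is already an integer class id, and the name class_mapping_function is mine, chosen to match the snippet below.

# Hypothetical class_func: with (feature, label) elements, the class is
# simply the label itself, cast to the int32 that rejection_resample expects.
def class_mapping_function(feature, label):
    return tf.cast(label, tf.int32)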
It turns out that the shape of the output Dataset's elements is different from that of the input Dataset: the resampled dataset yields (class, data) tuples. As a consequence, the returned Dataset (as implemented in TF 1.3) should be mapped back by the user like this:
balanced_dataset = tf.contrib.data.rejection_resample(input_dataset,
                                                      self.class_mapping_function,
                                                      self.target_distribution)

# Return to the same Dataset shape as the original input by dropping
# the class component that rejection_resample prepends to each element.
balanced_dataset = balanced_dataset.map(lambda _, data: data)
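To see the effect of the map, one can inspect the element structure (a quick check, assuming the (feature, label) toy elements from above):

# Before the map, elements are (class, (feature, label)) tuples;
# after it, they have the original (feature, label) structure again.
print(balanced_dataset.output_types)   # e.g. (tf.float32, tf.int32)
print(balanced_dataset.output_shapes)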
One note on the Iterator kind. As @mrry explained here, when using stateful objects within the pipeline one should use the initializable iterator and not the one-shot iterator. Note that when using the initializable iterator you should add the init op to TABLE_INITIALIZERS, or you will receive this error: "GetNext() failed because the iterator has not been initialized."
Code example:
# Create the iterator that allows access to elements from the dataset.
if self.use_balancing:
    # The balancing function uses stateful variables, in the sense that
    # they hold the current dataset distribution and calculate the next
    # distribution according to incoming examples.
    # For dataset pipelines that have state, a one-shot iterator will not
    # work, and we are forced to use an initializable iterator.
    # This should be relaxed in the future.
    # https://stackoverflow.com/questions/44374083/tensorflow-cannot-capture-a-stateful-node-by-value-in-tf-contrib-data-api
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
    iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
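For completeness, a minimal sketch of driving this at runtime (assuming a plain tf.Session; tf.tables_initializer() runs every op collected under TABLE_INITIALIZERS, including the iterator initializer added above):

with tf.Session() as sess:
    # rejection_resample keeps its distribution estimate in variables,
    # so initialize those as well as the table/iterator initializers.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    images, labels = sess.run([image_batch, label_batch])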
Does it work?
Yes. Here are 2 images from TensorBoard after collecting a histogram of the input pipeline labels. The original input labels were uniformly distributed.
Scenario A: Trying to achieve the following 10-class distribution: [0.1, 0.4, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.1, 0.1]
And the result:
Scenario B: Trying to achieve the following 10-class distribution: [0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05,0.4,0.1]
And the result:
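For reference, a sketch of how such a label histogram can be collected (the summary name and log directory are illustrative; label_batch comes from the iterator above):

# Record the labels flowing out of the input pipeline so TensorBoard
# can plot their distribution over training steps.
tf.summary.histogram('input_pipeline_labels', label_batch)
merged_summaries = tf.summary.merge_all()
writer = tf.summary.FileWriter('/tmp/balancing_demo')
# Inside the training loop:
#     summary = sess.run(merged_summaries)
#     writer.add_summary(summary, step)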