Oversampling functionality in the TensorFlow Dataset API

K Kolasinski

This problem has been solved in issue #14451. Just posting the answer here to make it more visible to other developers.

The sample code oversamples low-frequency classes and undersamples high-frequency ones, where class_target_prob is just a uniform distribution in my case. I wanted to check some conclusions from the recent manuscript A systematic study of the class imbalance problem in convolutional neural networks.

The oversampling of specific classes is done by calling:

dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

Here is the full snippet which does all the things:

import tensorflow as tf

# sampling parameters
oversampling_coef = 0.9  # if equal to 0 then oversample_classes() always returns 1
undersampling_coef = 0.5  # if equal to 0 then undersampling_filter() always returns True

def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften the ratio; if oversampling_coef==0 we recover the original distribution
    prob_ratio = prob_ratio ** oversampling_coef
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1) 
    # for low probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be e.g. 1.9, which means there is still a 90%
    # chance that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count  # a number between 0 and 1
    residual_acceptance = tf.less_equal(
        tf.random_uniform([], dtype=tf.float32), repeat_residual
    )

    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)

    return repeat_count + residual_acceptance


def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)

    acceptance = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)

    return acceptance


dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

dataset = dataset.filter(undersampling_filter)

dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

sess = tf.Session()
sess.run(tf.global_variables_initializer())
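
Note that the snippet assumes every example is a dict that already carries 'class_prob' and 'class_target_prob' features; the snippet itself does not build them. Here is a minimal sketch of how one might attach them, assuming a dataset of (image, label) pairs and a made-up empirical class distribution (class_probs, class_target_probs, and add_sampling_info are all hypothetical names):

import numpy as np

# assumed empirical class distribution (e.g. obtained by counting labels)
class_probs = np.array([0.7, 0.2, 0.1], dtype=np.float32)
# uniform target distribution over the three classes
class_target_probs = np.full(3, 1.0 / 3, dtype=np.float32)

def add_sampling_info(image, label):
    # pack the example into the dict layout the snippet above expects
    return {
        'image': image,
        'label': label,
        'class_prob': tf.gather(tf.constant(class_probs), label),
        'class_target_prob': tf.gather(tf.constant(class_target_probs), label),
    }

dataset = dataset.map(add_sampling_info)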

Update #1

Here is a simple Jupyter notebook which implements the above oversampling/undersampling on a toy model.

tf.data.experimental.rejection_resample seems to be a better way, since it does not require the "class_prob" and "class_target_prob" features.
Although it under-samples instead of over-sampling, with the same target distribution and number of training steps it works out the same.
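
For reference, a minimal sketch of that approach, assuming a dataset of (features, label) elements with two classes; the distributions below are made-up placeholders:

resampler = tf.data.experimental.rejection_resample(
    class_func=lambda features, label: label,  # maps each element to its class id
    target_dist=[0.5, 0.5],                    # desired class distribution
    initial_dist=[0.9, 0.1])                   # assumed empirical distribution
dataset = dataset.apply(resampler)
# rejection_resample yields (class_id, element) pairs; strip the extra class id
dataset = dataset.map(lambda class_id, element: element)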

This Q&A was very helpful for me, so I wrote a blog post about it together with my related experience.

https://vallum.github.io/Optimizing_parallel_performance_of_resampling_with_tensorflow.html

I hope someone who is interested in TensorFlow input pipeline optimization with re-sampling might get some ideas from it.

Some ops are probably redundant, but they were not major performance degraders in my case.

dataset = dataset.map(undersample_filter_fn, num_parallel_calls=num_parallel_calls)
dataset = dataset.flat_map(lambda x: x)

flat_map with the identity lambda is just for merging the surviving (and empty) records; see the sketch after the pseudo-code below.

# Pseudo-code for understanding flat_map after parallel maps
# parallel calls of map('A'), map('B'), and map('C')
map('A') = 'AAAAA' # replication of A 5 times
map('B') = ''      # B is dropped
map('C') = 'CC'    # replication of C twice
# merging all map results
flat_map('AAAAA', '', 'CC') = 'AAAAACC'
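
The post does not show undersample_filter_fn itself, so here is a hedged sketch of what it could look like under this pattern: unlike the boolean predicate used with filter() above, it must return a (possibly empty) dataset per element, which the identity flat_map then flattens (undersampling_coef is reused from the earlier snippet):

def undersample_filter_fn(example):
    # keep the example with probability prob_ratio, otherwise return an
    # empty dataset (repeat(0)), which flat_map simply skips
    prob_ratio = tf.cast(example['class_target_prob'] / example['class_prob'],
                         dtype=tf.float32)
    prob_ratio = tf.minimum(prob_ratio ** undersampling_coef, 1.0)
    accept = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)
    return tf.data.Dataset.from_tensors(example).repeat(tf.cast(accept, tf.int64))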