Oversampling functionality in the TensorFlow Dataset API

Submitted by 别来无恙 on 2019-11-30 04:22:22

Question


I would like to ask whether the current Dataset API allows an oversampling algorithm to be implemented. I am dealing with a highly imbalanced class problem, and I thought it would be nice to oversample specific classes during dataset parsing, i.e. with online generation. I've seen the implementation of the rejection_resample function; however, it removes samples instead of duplicating them, and it slows down batch generation (when the target distribution is very different from the initial one). What I would like to achieve is: take an example, look at its class probability, and decide whether to duplicate it or not. Then call dataset.shuffle(...) and dataset.batch(...) and get an iterator. The best approach (in my opinion) would be to oversample low-probability classes and subsample the most probable ones. I would like to do it online since it's more flexible.


Answer 1:


This problem has been solved in issue #14451. I'm just posting the answer here to make it more visible to other developers.

The sample code oversamples low-frequency classes and undersamples high-frequency ones, where class_target_prob is just a uniform distribution in my case. I wanted to check some conclusions from the recent manuscript A systematic study of the class imbalance problem in convolutional neural networks.

The oversampling of specific classes is done by calling:

dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)
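In plain Python terms, flat_map combined with from_tensors(x).repeat(n) replaces each element by n copies of itself (a count of zero drops it) and keeps the original order. The sketch below emulates that behaviour on a list, assuming non-negative counts; flat_map_repeat and the repeat_count callable are hypothetical stand-ins for the TensorFlow pipeline, not TensorFlow APIs.

```python
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def flat_map_repeat(elements: Iterable[T],
                    repeat_count: Callable[[T], int]) -> List[T]:
    """Hypothetical emulation of
    dataset.flat_map(lambda x: from_tensors(x).repeat(repeat_count(x)))."""
    out: List[T] = []
    for x in elements:
        # from_tensors(x) is a one-element dataset; .repeat(n) yields it n times
        out.extend([x] * repeat_count(x))
    return out


# Hypothetical example: duplicate the rare class "b", keep "a" once
print(flat_map_repeat(["a", "b", "a"], lambda x: 2 if x == "b" else 1))
# → ['a', 'b', 'b', 'a']
```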

Here is the full snippet that does all of this:

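import tensorflow as tf

# `dataset` is assumed to be a tf.data.Dataset whose elements are dicts
# containing 'class_prob' and 'class_target_prob' entries.

# sampling parameters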
oversampling_coef = 0.9  # if equal to 0 then oversample_classes() always returns 1
undersampling_coef = 0.5  # if equal to 0 then undersampling_filter() always returns True

def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften the ratio; if oversampling_coef == 0 we recover the original distribution
    prob_ratio = prob_ratio ** oversampling_coef 
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1) 
    # for low probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be e.g. 1.9, which means there is still a 90%
    # chance that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count  # a number between 0 and 1
    # uniform samples lie in [0, 1), so P(u < residual) == residual exactly
    residual_acceptance = tf.less(
        tf.random.uniform([], dtype=tf.float32), repeat_residual
    )

    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)

    return repeat_count + residual_acceptance


def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)

    acceptance = tf.less_equal(tf.random.uniform([], dtype=tf.float32), prob_ratio)

    return acceptance


dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

dataset = dataset.filter(undersampling_filter)

dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)

# In TF 2.x (eager execution) a dataset is a Python iterable, so no session
# or variable initialization is needed to read from it
iterator = iter(dataset)
next_element = next(iterator)
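As a sanity check on what the two sampling steps do to the class mix, the following framework-free sketch reimplements the same arithmetic with Python's random module rather than TensorFlow. The copy count is max(r**oversampling_coef, 1) rounded stochastically, so its expected value equals that ratio, and each copy then survives the filter with probability min(r**undersampling_coef, 1), where r = class_target_prob / class_prob. The helpers repeat_count, accept, expected_distribution and simulate are hypothetical, and the class frequencies [0.7, 0.2, 0.1] are made-up numbers.

```python
import math
import random
from typing import List


def repeat_count(ratio: float, coef: float, rng: random.Random) -> int:
    """Stochastically rounded copy count; its mean is max(ratio**coef, 1)."""
    r = max(ratio ** coef, 1.0)
    base = math.floor(r)
    # add one extra copy with probability equal to the fractional part
    return base + (1 if rng.random() < r - base else 0)


def accept(ratio: float, coef: float, rng: random.Random) -> bool:
    """Keep a copy with probability min(ratio**coef, 1)."""
    return rng.random() < min(ratio ** coef, 1.0)


def expected_distribution(p: List[float], t: List[float],
                          over: float, under: float) -> List[float]:
    """Closed-form class mix after oversampling followed by the filter."""
    w = [pc * max((tc / pc) ** over, 1.0) * min((tc / pc) ** under, 1.0)
         for pc, tc in zip(p, t)]
    total = sum(w)
    return [wc / total for wc in w]


def simulate(p: List[float], t: List[float], over: float, under: float,
             n: int, seed: int = 0) -> List[float]:
    """Monte Carlo version: draw n source examples and apply both steps."""
    rng = random.Random(seed)
    counts = [0] * len(p)
    for c in rng.choices(range(len(p)), weights=p, k=n):
        ratio = t[c] / p[c]
        for _ in range(repeat_count(ratio, over, rng)):
            # the filter is applied to every copy independently
            if accept(ratio, under, rng):
                counts[c] += 1
    total = sum(counts)
    return [k / total for k in counts]


p = [0.7, 0.2, 0.1]  # hypothetical class frequencies
t = [1 / 3] * 3      # uniform target, as in the answer
print(expected_distribution(p, t, 0.9, 0.5))  # → roughly [0.44, 0.29, 0.27]
print(simulate(p, t, 0.9, 0.5, 100_000))
```

With both coefficients set to 1.0 the weights reduce to p_c * (t_c / p_c) = t_c, i.e. the target distribution is recovered exactly; smaller coefficients trade balance for fewer duplicated or discarded examples.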

Update #1

Here is a simple Jupyter notebook that implements the above oversampling/undersampling on a toy model.




Answer 2:


tf.data.experimental.rejection_resample seems to be a better way, since it does not require the "class_prob" and "class_target_prob" features.
Although it undersamples instead of oversampling, with the same target distribution and the same number of training steps it should work the same.
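To see why undersampling can slow down batch generation (the concern raised in the question), here is a generic rejection-sampling sketch in plain Python. It is not the actual rejection_resample implementation; it only assumes the textbook scheme in which an example of class c is kept with probability proportional to t_c / p_c, scaled so the largest ratio maps to 1. The expected fraction of input that survives is then 1 / max_c(t_c / p_c). The helper names and class frequencies are hypothetical.

```python
from typing import List


def acceptance_probs(p: List[float], t: List[float]) -> List[float]:
    """Per-class keep probability for rejection sampling from p to t."""
    ratios = [tc / pc for pc, tc in zip(p, t)]
    m = max(ratios)
    return [r / m for r in ratios]


def kept_fraction(p: List[float], t: List[float]) -> float:
    """Expected fraction of input examples that survive the rejection step."""
    return sum(pc * a for pc, a in zip(p, acceptance_probs(p, t)))


p = [0.7, 0.2, 0.1]  # hypothetical class frequencies
t = [1 / 3] * 3      # uniform target
print(kept_fraction(p, t))  # → approximately 0.3
```

In this example roughly 3.3 input examples have to be read and parsed for every example that reaches a batch, which is the cost the question wants to avoid by duplicating rare examples instead.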




Answer 3:


This Q&A was very helpful to me, so I wrote a blog post about it based on my own experience.

https://vallum.github.io/Optimizing_parallel_performance_of_resampling_with_tensorflow.html

I hope anyone interested in optimizing a TensorFlow input pipeline that uses resampling finds some ideas in it.

Some ops are probably redundant, but they did not noticeably degrade performance in my case.

dataset = dataset.map(undersample_filter_fn, num_parallel_calls=num_parallel_calls)
dataset = dataset.flat_map(lambda x: x)

flat_map with the identity lambda simply merges the surviving (and possibly empty) per-record datasets into one flat dataset:

# Pseudo-code for understanding flat_map after map
# parallel calls of map('A'), map('B'), and map('C'):
map('A') = 'AAAAA'  # A replicated 5 times
map('B') = ''       # B is dropped
map('C') = 'CC'     # C replicated twice
# merging all map results
flat_map('AAAAA,,CC') = 'AAAAACC'
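The same two-stage idea can be run with ordinary Python lists: a per-record function returns a (possibly empty) list of copies, and a flattening step concatenates the results in order. The replicate function and the fixed copy counts are hypothetical, chosen only to match the pseudo-code; in the real pipeline the counts would be random.

```python
from itertools import chain
from typing import Dict, List

# Hypothetical copy counts matching the pseudo-code: A x5, B dropped, C x2
COPIES: Dict[str, int] = {"A": 5, "B": 0, "C": 2}


def replicate(record: str) -> List[str]:
    """Stage 1 (like map): return the record repeated, or [] to drop it."""
    return [record] * COPIES[record]


mapped = [replicate(r) for r in ["A", "B", "C"]]  # one list per input record
# Stage 2 (like flat_map(lambda x: x)): merge the per-record results in order
merged = "".join(chain.from_iterable(mapped))
print(merged)  # → AAAAACC
```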


Source: https://stackoverflow.com/questions/47236465/oversampling-functionality-in-tensorflow-dataset-api
