Produce balanced mini-batches with the Dataset API

余生分开走 2020-12-15 11:07

I have a question about the new Dataset API (TensorFlow 1.4rc1). I have an unbalanced dataset with respect to labels 0 and 1. My goal is to create balanced mini-batches.

1 Answer
  •  轻奢々 (OP)
    2020-12-15 11:23

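    You are on the right track. The following example uses Dataset.zip() to pair each positive example with a negative one. It then uses Dataset.flat_map() to turn each pair into two consecutive examples in the result. Because positive and negative examples strictly alternate, every full batch of 20 contains 10 of each: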

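    import tensorflow as tf

    # `ds_pos` and `ds_neg` are assumed to be `tf.data.Dataset`s containing only
    # positive and only negative examples, respectively. `Dataset.zip()` stops
    # as soon as the shorter of the two is exhausted.
    dataset = tf.data.Dataset.zip((ds_pos, ds_neg))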
    
    # Each input element will be converted into a two-element `Dataset` using
    # `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
    # will flatten the resulting `Dataset`s into a single `Dataset`.
    dataset = dataset.flat_map(
        lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(ex_pos).concatenate(
            tf.data.Dataset.from_tensors(ex_neg)))
    
    dataset = dataset.batch(20)
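
To see why each batch comes out balanced, here is a minimal pure-Python analogue of the pipeline above. It does not use TensorFlow. The helpers `interleave_pairs` and `batch` and the toy data are hypothetical. They only mimic the element order produced by Dataset.zip(), Dataset.flat_map() and Dataset.batch(), under the assumption that the input datasets are finite.

```python
from itertools import chain, islice


def interleave_pairs(pos, neg):
    """Hypothetical stand-in for Dataset.zip((pos, neg)) followed by flat_map.

    Yields pos[0], neg[0], pos[1], neg[1], ...  Like Dataset.zip(), it stops
    as soon as the shorter input is exhausted.
    """
    return chain.from_iterable(zip(pos, neg))


def batch(iterable, batch_size):
    """Hypothetical stand-in for Dataset.batch(): consecutive fixed-size groups.

    The final group may be smaller if the input does not divide evenly.
    """
    it = iter(iterable)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk


# Hypothetical toy data: 25 positive and 100 negative (id, label) examples.
positives = [(f"p{i}", 1) for i in range(25)]
negatives = [(f"n{i}", 0) for i in range(100)]

batches = list(batch(interleave_pairs(positives, negatives), 20))
for b in batches:
    n_pos = sum(label for _, label in b)
    print(len(b), n_pos)  # batch size, number of positives
# prints:
# 20 10
# 20 10
# 10 5
```

Only 50 of the 125 examples are used here. The zip stops at the shorter input, so the 75 surplus negatives are dropped. If that matters, one option (not part of the original answer) is to call .repeat() on the minority-class dataset before zipping. Its examples are then reused, and the zip ends only when the larger dataset runs out.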
    
