TensorFlow tf.data.Dataset and bucketing

风格不统一 提交于 2019-12-05 02:36:23

Various bucketing use cases using Dataset API are explained well here.

bucket_by_sequence_length() example:

def elements_gen():
   text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
   label = [1, 2, 1, 2]
   for x, y in zip(text, label):
       yield (x, y)

def element_length_fn(x, y):
   return tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                     output_shapes=([None],[]),
                                     output_types=(tf.int32, tf.int32))

dataset =   dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                              bucket_batch_sizes=[2, 2, 2],
                                                              bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:

   for _ in range(2):
      print('Get_next:')
      print(sess.run(batch))

Output:

Get_next:
(array([[1, 2, 3, 0, 0],
   [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
   [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!