How to get batch size back from a tensorflow dataset?

Posted by 痴心易碎 on 2020-12-31 06:35:39

Question


The recommended way to build an input pipeline in TensorFlow is the tf.data.Dataset API, which can be set up as follows:

# Specify dataset
dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
# Shuffle (buffer_size must be an integer, not a float such as 1e5)
dataset  = dataset.shuffle(buffer_size=100000)
# Specify batch size
dataset  = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()

I would like to be able to get the batch size back, either from the dataset itself or from an iterator created from it (that is, from iterator or next_batch). One might also want to know how many batches there are in the dataset or its iterator, or how many batches have already been consumed and how many remain. One might also want to get particular elements, or even the entire dataset at once.

I wasn't able to find anything in the TensorFlow documentation. Is this possible? If not, does anyone know whether this has been requested as an issue on the TensorFlow GitHub?


Answer 1:


Try this (TensorFlow 1.x, graph mode):

import tensorflow as tf
import numpy as np

features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(np.shape(sess.run(batch_data)[0])[0])
and you will see 2 printed, which is the size of the first batch.
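Note that the snippet above reports the size of the first batch only. With 3 samples and a batch size of 2, the second batch holds a single sample. The arithmetic can be sketched in plain Python as follows; `batch_sizes` is a hypothetical helper written for illustration, not a TensorFlow function, and it is only meant to mirror how `Dataset.batch` splits the data with and without `drop_remainder`.

```python
def batch_sizes(num_samples, batch_size, drop_remainder=False):
    """Return the size of each batch when num_samples items are batched."""
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    # The leftover samples form a smaller final batch unless they are dropped.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes

# The toy data in this answer: 3 samples, batch size 2.
sizes = batch_sizes(num_samples=3, batch_size=2)
print(len(sizes), sizes)  # 2 [2, 1]
```

So the number of batches is the length of that list, and the value printed by the answer is only its first entry.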


Answer 2:


In TF2, tf.data.Dataset objects are iterable, so you can get a batch by simply doing:

batch = next(iter(dataset))

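and then calculating the batch size is trivial, since it is the size of the first dimension. This form works when each element is a single tensor; if the dataset yields (features, labels) tuples, batch is a tuple and you need batch[0].shape[0] instead: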

batch_size = batch.shape[0]

So a complete example would look like:

# Specify dataset
dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
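# Shuffle (buffer_size must be an integer)
dataset  = dataset.shuffle(buffer_size=100000)
# Specify batch size
dataset  = dataset.batch(128)
# Each element is a (features, labels) tuple, so read the size from one component
features_batch, labels_batch = next(iter(dataset))
batch_size = features_batch.shape[0]
print('Batch size:', batch_size) # prints 128 if the dataset has at least 128 samples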

Or, if you need it as a function:

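def calculate_batch_size(dataset):
    # Flatten so this works for single tensors and (features, labels) tuples alike
    first_batch = tf.nest.flatten(next(iter(dataset)))
    return first_batch[0].shape[0]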

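Note that iterating over a dataset this way requires eager execution, which is the default in TF2. This solution also assumes that your dataset is batched: on an unbatched dataset it returns the first dimension of a single sample, or raises an error if the elements are scalars. The result may also be wrong if, after batching, you perform other operations on your dataset that change the shape of its elements. Finally, it reports the size of the first batch only; the last batch can be smaller unless you pass drop_remainder=True to batch().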
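If pulling a batch out of the dataset is not desirable, the batch size and the number of batches can often be read from the dataset's metadata instead. The following is a minimal sketch, assuming TF 2.x with eager execution; the toy arrays are made up for illustration. It relies on `Dataset.element_spec` and `tf.data.experimental.cardinality`. The static batch dimension is only known when `drop_remainder=True` is used, and it is `None` otherwise.

```python
import numpy as np
import tensorflow as tf

# Toy data (hypothetical): 10 samples with 2 features each.
features = np.arange(20, dtype="float32").reshape(10, 2)
labels = np.arange(10, dtype="float32").reshape(10, 1)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# drop_remainder=True makes every batch the same size, so the batch
# dimension becomes part of the static shape.
dataset = dataset.batch(4, drop_remainder=True)

# element_spec mirrors the (features, labels) structure of each element.
features_spec, labels_spec = dataset.element_spec
static_batch_size = features_spec.shape[0]  # None without drop_remainder=True

# Number of batches, when TensorFlow can determine it without iterating.
num_batches = int(tf.data.experimental.cardinality(dataset).numpy())

print(static_batch_size, num_batches)  # 4 2
```

The cardinality can also come back as `tf.data.experimental.UNKNOWN_CARDINALITY` or `tf.data.experimental.INFINITE_CARDINALITY` (for example after `filter` or `repeat()`), in which case counting by iteration is the fallback.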



Source: https://stackoverflow.com/questions/49912441/how-to-get-batch-size-back-from-a-tensorflow-dataset
