tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?


tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate prefix if applicable).

You could evaluate

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

to get the number of files.

Of course, this would work in simple cases only, and in particular if you have only one sample (or a known number of samples) per file.
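For example, in TF 1.x graph mode (a minimal sketch; the glob pattern is a placeholder):

import tensorflow as tf  # assumes TF 1.x graph mode

# list_files adds a MatchingFiles op to the default graph
dataset = tf.data.Dataset.list_files('/path/to/data/*.tfrecord')

# look the matched-files tensor up by name and take its length
num_files = tf.shape(
    tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

with tf.Session() as sess:
    print(sess.run(num_files))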

In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.

To do this, you can watch the number of epochs counted by your Dataset: repeat() creates a private member called _count that tracks the number of epochs. By observing it during iteration, you can spot when it changes and compute your dataset size from there.

This counter may be buried in the hierarchy of Datasets that is created when member functions are called successively, so we have to dig it out like this:

import warnings

import tensorflow as tf

d = my_dataset
# RepeatDataset is not exposed in the public API -- recovering the type from
# an instance is a possible workaround
RepeatDataset = type(tf.data.Dataset.range(1).repeat())
try:
  # walk down the chain of wrapped datasets until we reach the RepeatDataset
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count

Note that with this technique, the computed dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. So the computation is accurate only up to your batch size.

markemus

len(list(dataset)) works in eager mode, although it materializes the whole dataset in memory, so it's obviously not a good general solution.

Unfortunately, I don't believe there is a feature like that in TF yet. With TF 2.0 and eager execution, however, you can just iterate over the dataset:

num_elements = 0
for element in dataset:
    num_elements += 1

This is the most storage-efficient way I could come up with.
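If you'd rather not write the Python loop yourself, a reduce over the dataset gives the same count without materializing any elements (a sketch, assuming TF 2.x and eager execution; the range dataset is a placeholder):

import tensorflow as tf

dataset = tf.data.Dataset.range(100)  # placeholder dataset

# fold a counter over the dataset; every element increments it by one
num_elements = dataset.reduce(tf.constant(0, tf.int64),
                              lambda count, _: count + 1).numpy()
print(num_elements)  # 100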

This really feels like a feature that should have been added a long time ago. Fingers crossed they add a length feature in a later version.

Take a look here: https://github.com/tensorflow/tensorflow/issues/26966

It doesn't work for TFRecord datasets, but it works fine for other types.

TL;DR:

num_elements = tf.data.experimental.cardinality(dataset).numpy()
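cardinality returns sentinel values when the size cannot be determined statically, so it is worth checking for them. A short sketch, assuming TF 2.x:

import tensorflow as tf

dataset = tf.data.Dataset.range(42)
cardinality = tf.data.experimental.cardinality(dataset)

if cardinality == tf.data.experimental.UNKNOWN_CARDINALITY:
    # e.g. TFRecord datasets: the size is only known after a full pass
    num_elements = sum(1 for _ in dataset)
elif cardinality == tf.data.experimental.INFINITE_CARDINALITY:
    num_elements = None  # e.g. repeat() without a count
else:
    num_elements = cardinality.numpy()

print(num_elements)  # 42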

Albert James Teddy

The following code works in TF2:

data._tensors[0].shape[0]

data has to be an unbatched dataset created with from_tensor_slices for this to work, since _tensors is a private attribute of that dataset type.
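For example (a sketch with placeholder data; because _tensors is a private attribute, this may break across TF versions):

import numpy as np
import tensorflow as tf

features = np.zeros((500, 32), dtype=np.float32)  # placeholder data
data = tf.data.Dataset.from_tensor_slices(features)

# _tensors holds the original tensors before slicing; the first dimension
# of the first tensor is the number of elements in the dataset
print(data._tensors[0].shape[0])  # 500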
