TensorFlow - tf.data.Dataset reading large HDF5 files

Posted by 混江龙づ霸主 on 2019-11-30 04:50:51

I stumbled across this question while dealing with a similar issue. I came up with a solution based on using a Python generator, together with the TF dataset construction method from_generator. Because we use a generator, the HDF5 file should be opened for reading only once and kept open as long as there are entries to read. So it will not be opened, read, and then closed for every single call to get the next data element.

Generator definition

To allow the user to pass in the HDF5 filename as an argument, I wrote a class with a __call__ method, since from_generator requires the generator argument to be callable. This is the generator:

import h5py
import tensorflow as tf

class generator:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        with h5py.File(self.file, 'r') as hf:
            for im in hf["train_img"]:
                yield im

Because this is a generator, each request for the next element resumes execution where the previous yield left off, instead of running the function from the beginning again. Here that means the next iteration of the inner for loop. The file is therefore not reopened for each element; it stays open as long as there is data to yield. For more on generators, see this excellent Q&A.
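To make the resume behaviour concrete, here is a minimal sketch that needs neither TensorFlow nor h5py. FakeFile is a hypothetical stand-in for h5py.File that only counts how often it is opened; it is an assumption for illustration and not part of the original answer.

```python
class FakeFile:
    """Hypothetical stand-in for h5py.File that counts how often it is opened."""
    open_count = 0

    def __init__(self, rows):
        self.rows = rows

    def __enter__(self):
        FakeFile.open_count += 1
        return self.rows

    def __exit__(self, exc_type, exc, tb):
        return False  # do not swallow exceptions


class generator:
    def __init__(self, rows):
        self.rows = rows

    def __call__(self):
        with FakeFile(self.rows) as hf:  # entered once per call to the object
            for row in hf:
                yield row                # execution pauses here between elements


gen = generator(["img0", "img1", "img2"])
it = gen()                  # nothing runs yet; the body starts on the first next()
first = next(it)            # opens the "file" and yields the first row
second = next(it)           # resumes inside the loop, no second open
opens_after_two = FakeFile.open_count
rest = list(it)             # drains the remaining rows and leaves the with block
```

After two elements have been read, open_count is still 1, which is the property the real generator relies on.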

Of course, you will have to replace anything inside the with block to match how your dataset is constructed and what outputs you want to obtain.

Usage example

ds = tf.data.Dataset.from_generator(
    generator(hdf5_path), 
    tf.uint8, 
    tf.TensorShape([427,561,3]))

value = ds.make_one_shot_iterator().get_next()

# Example of how to read elements (TF 1.x graph mode)
with tf.Session() as sess:
    while True:
        try:
            data = sess.run(value)
            print(data.shape)
        except tf.errors.OutOfRangeError:
            print('done.')
            break

Again, in my case I had stored uint8 images of height 427, width 561, and 3 color channels in my dataset, so you will need to modify these in the above call to match your use case.

Handling multiple files

I have a proposed solution for handling multiple HDF5 files. The basic idea is to construct a Dataset from the filenames as usual, and then use the interleave method to process many input files concurrently, getting samples from each of them to form a batch, for example.

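The idea is as follows. Note that inside the lambda, filename is a Tensor rather than a Python string, so generator(filename) as written here does not hand h5py a usable path; the follow-up further down passes the filename through args instead: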

ds = tf.data.Dataset.from_tensor_slices(filenames)
# You might want to shuffle() the filenames here depending on the application
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(filename), 
        tf.uint8, 
        tf.TensorShape([427,561,3])),
       cycle_length, block_length)

What this does is open cycle_length files concurrently and produce block_length items from each before moving to the next file; see the interleave documentation for details. Set the values to suit your application: for example, whether you need to process one file at a time or several concurrently, and whether you want a single sample at a time from each file or a run of several.
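The ordering that cycle_length and block_length produce can be sketched in plain Python. This is a simplified model of the behaviour described above, not TensorFlow's implementation; interleave_model and the files dictionary are hypothetical names used only for illustration.

```python
from itertools import islice


def interleave_model(inputs, make_iter, cycle_length, block_length):
    """Simplified pure-Python model of interleave ordering (an assumption,
    not TensorFlow code)."""
    input_iter = iter(inputs)
    slots = [None] * cycle_length      # currently open per-file iterators
    inputs_left = True
    while True:
        visited_open_slot = False
        for i in range(cycle_length):
            # Refill an empty slot with the next input, if any remain.
            if slots[i] is None and inputs_left:
                try:
                    slots[i] = iter(make_iter(next(input_iter)))
                except StopIteration:
                    inputs_left = False
            if slots[i] is None:
                continue
            visited_open_slot = True
            produced = 0
            # Take up to block_length consecutive items from this slot.
            for item in islice(slots[i], block_length):
                produced += 1
                yield item
            if produced < block_length:
                slots[i] = None        # this file is exhausted
        if not visited_open_slot:
            return


# Three hypothetical "files" with three samples each.
files = {
    "a": ["a0", "a1", "a2"],
    "b": ["b0", "b1", "b2"],
    "c": ["c0", "c1", "c2"],
}
order = list(interleave_model(["a", "b", "c"], lambda name: files[name],
                              cycle_length=2, block_length=2))
# order == ['a0', 'a1', 'b0', 'b1', 'a2', 'b2', 'c0', 'c1', 'c2']
```

With cycle_length=2, file c is only opened once a slot frees up, which is why its samples come last.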

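Edit: for a parallel version, take a look at tf.contrib.data.parallel_interleave! (tf.contrib was removed in TensorFlow 2.x, where Dataset.interleave accepts a num_parallel_calls argument instead.)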

Possible caveats

Be aware of the peculiarities of from_generator if you decide to go with this solution. For TensorFlow 1.6.0, the documentation of from_generator includes these two notes.

It may be challenging to apply this across different environments or with distributed training:

NOTE: The current implementation of Dataset.from_generator() uses tf.py_func and inherits the same constraints. In particular, it requires the Dataset- and Iterator-related operations to be placed on a device in the same process as the Python program that called Dataset.from_generator(). The body of generator will not be serialized in a GraphDef, and you should not use this method if you need to serialize your model and restore it in a different environment.

Be careful if the generator depends on external state:

NOTE: If generator depends on mutable global variables or other external state, be aware that the runtime may invoke generator multiple times (in order to support repeating the Dataset) and at any time between the call to Dataset.from_generator() and the production of the first element from the generator. Mutating global variables or external state can cause undefined behavior, and we recommend that you explicitly cache any external state in generator before calling Dataset.from_generator().
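The second note can be illustrated without TensorFlow. The sketch below assumes only that the callable may be invoked more than once, as the note says; LiveGenerator, SnapshotGenerator and the filenames list are hypothetical names for illustration.

```python
filenames = ["a.h5", "b.h5"]  # hypothetical mutable global state


class LiveGenerator:
    """Reads the global list every time it is invoked."""

    def __call__(self):
        for name in filenames:
            yield name


class SnapshotGenerator:
    """Copies the list once, at construction time."""

    def __init__(self, names):
        self.names = tuple(names)  # cached before any invocation

    def __call__(self):
        for name in self.names:
            yield name


live = LiveGenerator()
snap = SnapshotGenerator(filenames)

first_pass_live = list(live())
first_pass_snap = list(snap())

filenames.append("c.h5")          # external state changes between invocations

second_pass_live = list(live())   # now also yields "c.h5"
second_pass_snap = list(snap())   # unchanged, still two entries
```

Caching the state in __init__, as the generator class in this answer does with the filename, keeps repeated invocations consistent.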

It took me a while to figure this out, so I thought I should record it here. Based on mikkola's answer, this is how to handle multiple files:

import h5py
import tensorflow as tf

class generator:
    def __call__(self, file):
        with h5py.File(file, 'r') as hf:
            for im in hf["train_img"]:
                yield im

ds = tf.data.Dataset.from_tensor_slices(filenames)
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(), 
        tf.uint8, 
        tf.TensorShape([427,561,3]),
        args=(filename,)),
       cycle_length, block_length)

The key is that you can't pass filename directly to generator, since it's a Tensor. You have to pass it through args, which TensorFlow evaluates and converts to a regular Python value before calling the generator.
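One detail worth knowing: a value passed through args comes back from TensorFlow as NumPy data, so a string filename typically arrives as a bytes object rather than a str. If the generator needs a str (for example for path manipulation or logging), a small helper can normalise it. as_text is a hypothetical name, and UTF-8 is an assumption about how the filenames are encoded.

```python
def as_text(name):
    """Return name as str, decoding it if it arrived as bytes.

    Hypothetical helper; assumes UTF-8 encoded filenames.
    """
    if isinstance(name, bytes):
        return name.decode("utf-8")
    return str(name)


from_tensor = as_text(b"train_0.h5")   # what the generator typically receives via args
from_python = as_text("train_0.h5")    # plain Python strings pass through unchanged
```

Inside __call__, this would be applied to the file argument before using it as a path.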
