TensorFlow: `tf.data.Dataset.from_generator()` does not work with strings on Python 3.x

Submitted by 谁说我不能喝 on 2019-12-06 04:29:13

Question


I need to iterate through a large number of image files and feed the data to TensorFlow. I created a Dataset backed by a generator function that produces the file paths as strings, and then transform each path into image data using map. It fails because generating string values does not work, as shown below. Is there a fix or workaround for this?

2017-12-07 15:29:05.820708: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
producing data/miniImagenet/val/n01855672/n0185567200001000.jpg
2017-12-07 15:29:06.009141: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
2017-12-07 15:29:06.009215: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,21168]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

The test code is shown below. It works correctly with from_tensor_slices, or by first putting the file name list in a tensor. However, either workaround would exhaust GPU memory.

import tensorflow as tf

if __name__ == "__main__":
    file_names = ['data/miniImagenet/val/n01855672/n0185567200001000.jpg',
                  'data/miniImagenet/val/n01855672/n0185567200001005.jpg']
    # note: converting the file list to tensor and returning an index from generator works
    # path_to_indexes = {p: i for i, p in enumerate(file_names)}
    # file_names_tensor = tf.convert_to_tensor(file_names)

    def dataset_producer():
        for s in file_names:
            print('producing', s)
            yield s
    dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                             output_shapes=(tf.TensorShape([])))

    # note: this would also work
    # dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(file_names))

    def read_image(filename):
        # filename = file_names_tensor[filename_index]
        image_file = tf.read_file(filename, name='read_file')
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((84,84,3))
        image = tf.reshape(image, [21168])
        image = tf.cast(image, tf.float32) / 255.0
        return image

    dataset = dataset.map(read_image)
    dataset = dataset.batch(2)
    data_iterator = dataset.make_one_shot_iterator()
    images = data_iterator.get_next()
    print('images', images)
    max_value = tf.argmax(images)
    with tf.Session() as session:
        result = session.run(max_value)
        print(result)

Answer 1:


This is a bug affecting Python 3.x that was fixed after the TensorFlow 1.4 release. All releases of TensorFlow from 1.5 onwards contain the fix.
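To decide at runtime whether the workaround is needed, one option is a small version check. This is only a sketch: `needs_bytes_workaround` is a hypothetical helper (not part of TensorFlow), and it assumes the version string has the usual `major.minor...` form, such as the value of `tf.__version__`.

```python
import sys


def needs_bytes_workaround(tf_version, py_major=sys.version_info[0]):
    """Return True if strings yielded by a generator should be encoded to bytes.

    Hypothetical helper: encodes the rule "Python 3.x and TensorFlow older than 1.5".
    `tf_version` is a version string such as "1.4.0".
    """
    # Only the major and minor components matter for this check.
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return py_major >= 3 and (major, minor) < (1, 5)
```

In a real script the first argument would typically be `tf.__version__`.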

If you must use an earlier version, the workaround is to convert the strings to bytes before yielding them from the generator. The following code should work:

def dataset_producer():
    for s in file_names:
        print('producing', s)
        yield s.encode('utf-8')  # Convert `s` to `bytes`.

dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                         output_shapes=(tf.TensorShape([])))
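If several generators need the same treatment, the encoding step could be factored into a wrapper instead of being repeated in each one. The sketch below is an illustration, not part of the original answer: `as_bytes` is a hypothetical helper name, and it assumes the generator yields single scalar values (not tuples or dicts).

```python
def as_bytes(generator_fn, encoding="utf-8"):
    """Wrap a generator function so that every `str` it yields becomes `bytes`.

    Hypothetical helper; values that are not `str` are passed through unchanged.
    """
    def wrapped():
        for item in generator_fn():
            # Encode text to bytes; leave bytes and other values untouched.
            yield item.encode(encoding) if isinstance(item, str) else item
    return wrapped


# Intended use with the generator from the question (needs TensorFlow installed):
# dataset = tf.data.Dataset.from_generator(as_bytes(dataset_producer),
#                                          output_types=tf.string,
#                                          output_shapes=tf.TensorShape([]))
```

The wrapper returns a callable rather than a generator object because `from_generator` expects a callable that it can invoke to obtain a fresh iterator.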


Source: https://stackoverflow.com/questions/47705684/tensorflow-tf-data-dataset-from-generator-does-not-work-with-strings-on-pyt
