tf.data.Dataset.map(map_func) with eager mode

Submitted by 徘徊边缘 on 2020-08-07 04:41:45

Question


I am using TF 1.8 with eager mode enabled.

I cannot print the example inside mapfunc. Also, when I run tf.executing_eagerly() from within mapfunc, I get False.

import os
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

tfe = tf.contrib.eager
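tf.enable_eager_execution()
# 16 rows of 10 random int64 values in [-10, 0)
x = tf.random_uniform([16, 10], -10, 0, tf.int64)
print(x)
# Each dataset element is one row of shape (10,)
DS = tf.data.Dataset.from_tensor_slices(x)


def mapfunc(ex, con):
    import pdb; pdb.set_trace()  # breakpoint to inspect `ex`
    new_ex = ex + con
    # Inside map() this prints a symbolic Tensor, not the element's values
    print(new_ex)
    return new_ex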

DS = DS.map(lambda x: mapfunc(x, [7]))
DS = DS.make_one_shot_iterator()

print(DS.next())

print(new_ex) outputs:

Tensor("add:0", shape=(10,), dtype=int64)

Outside mapfunc this works fine, but inside it the passed example has no value and no .numpy() method.


Answer 1:


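The tf.data transformations actually execute as a graph: the Python map function is only called to trace (build) that graph, so its body isn't executed eagerly. This is why print() inside it shows a symbolic Tensor such as "add:0" and tf.executing_eagerly() returns False. See #14732 for more discussion on this.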
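The same effect can be reproduced in a minimal sketch. It assumes TensorFlow 2.x (eager by default) rather than the TF 1.8 used in the question, and the names add_seven and trace_flags are illustrative only. The Python body of the map function runs while tf.data traces it, and tf.executing_eagerly() reports False there, while iterating the finished dataset from outside is eager and yields concrete values.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

# Python-side record of tf.executing_eagerly() as seen inside the map function.
# It is appended to while tf.data traces add_seven into a graph,
# not once per dataset element.
trace_flags = []


def add_seven(x):
    trace_flags.append(tf.executing_eagerly())
    print("inside map:", x)  # prints a symbolic tensor, not values
    return x + 7


ds = tf.data.Dataset.range(3).map(add_seven)

# Iteration happens outside the map function, so the elements are concrete.
values = [int(v) for v in ds]
print(values)       # [7, 8, 9]
print(trace_flags)  # only False entries, one per trace
```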

If you really need eager execution inside the map function, you could wrap it in tf.contrib.eager.py_func, which calls the Python function eagerly on each element, for example:

# mapfunc now receives EagerTensors, so print() and .numpy() work inside it
DS = DS.map(lambda x: tf.contrib.eager.py_func(
    mapfunc,
    [x, tf.constant(7, dtype=tf.int64)], tf.int64))
# In TF 1.9+, the next line can be print(next(iter(DS)))
print(DS.make_one_shot_iterator().next())
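In TensorFlow 2.x the eager-capable wrapper is tf.py_function, the successor of tf.contrib.eager.py_func. The sketch below assumes TF 2.x and uses small deterministic input instead of random values; add_con and map_fn are hypothetical names. Inside the wrapped function the arguments are EagerTensors, so print() and .numpy() work as the question expected.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

eager_flags = []  # tf.executing_eagerly() as seen inside the py_function


def add_con(ex, con):
    # Called eagerly once per element: ex and con are EagerTensors here.
    eager_flags.append(tf.executing_eagerly())
    print("inside py_function:", ex.numpy())
    return ex + con


def map_fn(x):
    y = tf.py_function(add_con, [x, tf.constant([7], dtype=tf.int64)], Tout=tf.int64)
    # py_function output has an unknown static shape; restore it from the input.
    y.set_shape(x.shape)
    return y


x = tf.reshape(tf.range(20, dtype=tf.int64), [2, 10])
ds = tf.data.Dataset.from_tensor_slices(x).map(map_fn)

rows = [row.numpy().tolist() for row in ds]
print(rows)
```

The set_shape call is a design choice rather than a requirement: without it, downstream ops such as batch() would see elements of unknown shape.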

Hope that helps.

Note that adding a py_func to the dataset puts the single-threaded (GIL-bound) Python interpreter in the loop for every element produced, which can make it the bottleneck of the input pipeline.
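If the goal is only to see values while debugging, one alternative that avoids py_func altogether is tf.print, which is itself a graph op and so prints the runtime value of each element without routing the computation through the Python interpreter. This is a sketch assuming TensorFlow 2.x; add_and_log is a hypothetical name.

```python
import tensorflow as tf  # assumes TensorFlow 2.x


def add_and_log(ex):
    new_ex = ex + 7
    # tf.print is a graph op: it runs once per element with concrete values
    # and writes to stderr by default.
    tf.print("new_ex:", new_ex)
    return new_ex


x = tf.reshape(tf.range(6, dtype=tf.int64), [2, 3])
ds = tf.data.Dataset.from_tensor_slices(x).map(add_and_log)

results = [row.numpy().tolist() for row in ds]
print(results)  # [[7, 8, 9], [10, 11, 12]]
```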



Source: https://stackoverflow.com/questions/50538038/tf-data-dataset-mapmap-func-with-eager-mode
