Flatten a dataset in TensorFlow

Posted by 为君一笑 on 2019-12-04 04:10:10

Question


I am trying to transform a TensorFlow dataset so that each element is a single-valued tensor. The dataset currently looks like this:

[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...

After the transformation it should look like this:

[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...

My initial idea was to use flat_map on the dataset and then convert each tensor to a list of tensors using reshape and unstack:

output_labels = self.dataset.flat_map(convert_labels)

...

def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

However, the shape of each tensor is only partially known (i.e. (?, 1)), which is why the unstack operation fails. Is there a way to still "concat" the different tensors without explicitly iterating over them?


Answer 1:


Your solution is very close, but Dataset.flat_map() takes a function that returns a tf.data.Dataset object, rather than a list of tensors. Fortunately, the Dataset.from_tensor_slices() method works for exactly your use case, because it can split a tensor into a variable number of elements:

output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)
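For a self-contained illustration, here is a minimal sketch of the same flat_map call. It assumes TensorFlow 2.4 or later running eagerly; the sample rows and the names rows, dataset, flat and flat_values are made up for the example and stand in for self.dataset from the question.

```python
import tensorflow as tf

# Hypothetical sample data: rows of differing length, as in the question.
rows = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]

# Build a dataset whose elements are 1-D tensors of unknown length
# (static shape [None]), mimicking the partially known shapes in the question.
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

# from_tensor_slices turns each 1-D element into a small dataset of scalars;
# flat_map concatenates those small datasets in order.
flat = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

# Collect the flattened values as plain Python ints for inspection.
flat_values = [int(x) for x in flat.as_numpy_iterator()]
```

Each element of flat is a scalar here. If elements of shape (1,) are wanted instead, as written in the question, one option is to reshape each input tensor to [-1, 1] inside the function passed to flat_map before slicing it.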

Note that the tf.contrib.data.unbatch() transformation implements the same functionality, and has a slightly more efficient implementation in the current master branch of TensorFlow (will be included in the 1.9 release):

output_labels = self.dataset.apply(tf.contrib.data.unbatch())
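The tf.contrib namespace no longer exists in TensorFlow 2.x, where unbatching is available as a method on the dataset itself. The following is a sketch under that assumption (TensorFlow 2.4 or later, eager execution), using made-up sample rows and variable names rather than the original self.dataset.

```python
import tensorflow as tf

# Hypothetical sample data with rows of differing length.
rows = [[12, 43, 64], [34, 65], [23]]

# Elements are 1-D tensors whose length is not known statically.
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

# unbatch() splits every element along its first dimension,
# yielding one scalar per original value.
unbatched = dataset.unbatch()

unbatched_values = [int(x) for x in unbatched.as_numpy_iterator()]
```

On this input the result should match what flat_map with from_tensor_slices produces: the same values in the same order, one per element.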


Source: https://stackoverflow.com/questions/49960875/flatten-a-dataset-in-tensorflow
