TensorFlow - String processing in Dataset API

Posted by 谁都会走 on 2019-12-05 17:47:37

The object lines_split.values[0] is a tf.Tensor object representing the 0th split from line. It is not a Python string, and so it does not have a .strip() or .lower() method. Instead you will have to apply TensorFlow operations to the tensor to perform the conversion.

TensorFlow currently doesn't have very many string operations, but you can use the tf.py_func() op to run some Python code on a tf.Tensor:

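import tensorflow as tf  # TensorFlow 1.x API

def _parse_data(line):
    # Split the scalar string on tabs; the result is a SparseTensor.
    line_split = tf.string_split([line], '\t')

    # tf.py_func() takes its inputs as a list of tensors.
    raw_text = tf.py_func(
        lambda x: x.strip().lower(), [line_split.values[0]], tf.string)

    # Convert the second field to an integer label.
    label = tf.string_to_number(line_split.values[1], out_type=tf.int32)

    return {"raw_text": raw_text, "label": label}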
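To make the Python side of this concrete: under Python 3, the callable passed to tf.py_func() should receive the value of a scalar tf.string tensor as a bytes object, so .strip() and .lower() there are the bytes methods (and .lower() only affects ASCII letters). The following is a minimal plain-Python sketch of the same per-line logic, without TensorFlow. The function name parse_line and the sample input are hypothetical and used only for illustration.

```python
def parse_line(line: bytes) -> dict:
    """Plain-Python stand-in for the per-line logic of _parse_data."""
    # Split the tab-separated line into its fields.
    fields = line.split(b"\t")

    # bytes.strip() removes surrounding ASCII whitespace;
    # bytes.lower() lowercases ASCII letters only.
    raw_text = fields[0].strip().lower()

    # Strip the trailing newline, then parse the ASCII digits as an integer.
    label = int(fields[1].strip())

    return {"raw_text": raw_text, "label": label}


example = parse_line(b"  Hello World \t1\n")
print(example)
```

If the text can contain non-ASCII characters, one option is to decode the bytes inside the callable, lowercase the resulting str, and encode it again before returning.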

Note that there are a couple of other problems with the code in the question:

  • Don't use tf.parse_single_example(). This op is only used for parsing tf.train.Example protocol buffer strings; you do not need to use it when parsing text, and you can return the extracted features directly from _parse_data().
  • Use dataset.map() instead of dataset.flat_map(). You only need to use flat_map() when the result of your mapping function is a Dataset object (and hence the return values need to be flattened into a single dataset). You must use map() when the result is one or more tf.Tensor objects.
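The difference between map() and flat_map() can be illustrated with a plain-Python analogy. This is only a sketch using lists and itertools, not the tf.data API itself; the helper names to_record and to_fields and the sample lines are hypothetical.

```python
from itertools import chain

lines = ["a\t1", "b\t0"]


def to_record(line):
    # One input element becomes exactly one output element.
    text, label = line.split("\t")
    return (text, int(label))


def to_fields(line):
    # One input element becomes a whole sequence of elements,
    # standing in for a function that returns a Dataset.
    return line.split("\t")


# map-style: the output has one element per input element.
mapped = [to_record(line) for line in lines]

# flat_map-style: the per-element sequences are concatenated into one.
flat_mapped = list(chain.from_iterable(to_fields(line) for line in lines))

print(mapped)
print(flat_mapped)
```

Since _parse_data() returns tensors for a single record rather than a nested dataset, it corresponds to the map-style case.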