Using string labels in TensorFlow

Submitted by 谁说胖子不能爱 on 2019-12-01 04:11:27
Answer by mrry:

The convert_to_records.py script creates a .tfrecords file in which each record is an Example protocol buffer. That protocol buffer supports string features using the bytes_list kind.

The tf.decode_raw op is used to parse binary strings into image data; it is not designed to parse string (textual) labels. Assuming that features['label'] is a tf.string tensor, you can use the tf.string_to_number op to convert it to a number. Support for string processing inside a TensorFlow program is otherwise limited, so if you need a more complicated function to convert the string label to an integer, you should perform this conversion in Python, in your modified version of convert_to_records.py.
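Both options can be sketched as follows. This is a minimal sketch assuming TensorFlow 2.x, where tf.string_to_number is spelled `tf.strings.to_number` and the parsing helpers live under `tf.io`. The feature name `label`, the example values, and the `LABEL_TO_ID` vocabulary are hypothetical.

```python
import tensorflow as tf

def string_label_example(label: bytes) -> bytes:
    """Serialize an Example whose 'label' feature uses the bytes_list kind."""
    feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
    example = tf.train.Example(features=tf.train.Features(feature={"label": feature}))
    return example.SerializeToString()

# Option 1: the label is a numeric string such as b"7"; parse it as tf.string
# and convert it to a number inside TensorFlow.
parsed_numeric = tf.io.parse_single_example(
    tf.constant(string_label_example(b"7")),
    {"label": tf.io.FixedLenFeature([], tf.string)})
numeric_label = tf.strings.to_number(parsed_numeric["label"], out_type=tf.int32)
print(int(numeric_label))  # 7

# Option 2: the label is a class name such as "dog"; map it to an integer in
# plain Python *before* writing, and store it as an int64_list feature instead.
LABEL_TO_ID = {"cat": 0, "dog": 1}  # hypothetical vocabulary

def class_label_example(name: str) -> bytes:
    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[LABEL_TO_ID[name]]))
    example = tf.train.Example(features=tf.train.Features(feature={"label": feature}))
    return example.SerializeToString()

parsed_class = tf.io.parse_single_example(
    tf.constant(class_label_example("dog")),
    {"label": tf.io.FixedLenFeature([], tf.int64)})
print(int(parsed_class["label"]))  # 1
```

Doing the mapping in Python keeps the reading pipeline trivial: the parser only ever sees integers, so no string ops are needed at training time.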

To add to @mrry's answer: assuming your string is ASCII, you can write it as a bytes feature:

import tensorflow as tf  # TensorFlow 1.x API

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def write_proto(cls, filepath, ..., item_id):  # item_id is an ASCII-encodable string
    # ...
    with tf.python_io.TFRecordWriter(filepath) as writer:
        example = tf.train.Example(features=tf.train.Features(feature={
            # write it as a bytes array, assuming your string is ASCII
            'item_id': _bytes_feature(bytes(item_id, encoding='ascii')), # python 3
            # ...
        }))
        writer.write(example.SerializeToString())

Then, to parse it:

def parse_single_example(cls, example_proto, graph=None):
    features_dict = tf.parse_single_example(example_proto,
        features={'item_id': tf.FixedLenFeature([], tf.string),
        # ...
        })
    # decode as uint8 aka bytes
    item_id = tf.decode_raw(features_dict['item_id'], tf.uint8)
    return item_id  # plus any other features you parsed
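To see what decoding as uint8 produces without running a TensorFlow session, here is a minimal NumPy-only sketch. It assumes an ASCII string, as the answer does; `np.frombuffer` stands in for the output of `tf.decode_raw(..., tf.uint8)`, which is a 1-D array holding one uint8 per byte. The value `item-42` is made up.

```python
import numpy as np

item_id = "item-42"  # hypothetical ASCII item id

# Writing side: the string is stored as raw ASCII bytes.
raw = bytes(item_id, encoding="ascii")

# Reading side: interpreting the bytes as uint8 gives one number per character,
# analogous to decode_raw with tf.uint8.
as_uint8 = np.frombuffer(raw, dtype=np.uint8)
print(as_uint8[:4])  # [105 116 101 109]

# Back to a string: str() accepts any contiguous bytes-like object ...
restored = str(as_uint8.flatten(), "ascii")
# ... and tobytes().decode() is an equivalent, more explicit spelling.
restored_alt = as_uint8.tobytes().decode("ascii")
print(restored, restored_alt)
```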

Then, when you fetch it in your session, convert it back to a string:

item_id, ... = session.run(your_tfrecords_iterator.get_next())
print(str(item_id.flatten(), 'ascii')) # python 3

I took the uint8 trick from a related Stack Overflow answer. It works for me, but comments and improvements are welcome.
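On TensorFlow 2.x, `tf.python_io`, `tf.parse_single_example`, `tf.decode_raw` and sessions are no longer available at the top level. A rough equivalent of the round trip above might use `tf.io.TFRecordWriter`, `tf.data.TFRecordDataset`, `tf.io.parse_single_example` and `tf.io.decode_raw`, as in the sketch below; the temporary file path and the `item-42` value are made up. It also shows a possibly simpler alternative: the parsed scalar `tf.string` tensor already holds the raw bytes, so you can skip `decode_raw` and call `.decode()` on the fetched value.

```python
import os
import tempfile

import tensorflow as tf

def bytes_feature(value: bytes) -> tf.train.Feature:
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Hypothetical output location for this sketch.
path = os.path.join(tempfile.mkdtemp(), "items.tfrecords")

# Write one record whose item_id is stored as ASCII bytes.
with tf.io.TFRecordWriter(path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        "item_id": bytes_feature("item-42".encode("ascii")),
    }))
    writer.write(example.SerializeToString())

def parse(example_proto):
    parsed = tf.io.parse_single_example(
        example_proto, {"item_id": tf.io.FixedLenFeature([], tf.string)})
    # Return both the uint8 view (the trick above) and the raw string tensor.
    return tf.io.decode_raw(parsed["item_id"], tf.uint8), parsed["item_id"]

for as_uint8, as_string in tf.data.TFRecordDataset(path).map(parse):
    via_uint8 = as_uint8.numpy().tobytes().decode("ascii")
    via_string = as_string.numpy().decode("ascii")  # no decode_raw needed
    print(via_uint8, via_string)
```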
