Using string labels in TensorFlow

野的像风 2021-01-12 03:50

I'm still trying to run TensorFlow with my own image data. I was able to create a .tfrecords file with the convert_to() function from this linked example.

Now I'd l…

2 Answers
  •  不要未来只要你来
    2021-01-12 04:19

    To add to @mrry's answer: assuming your string is ASCII, you can store it as a bytes feature:

    import tensorflow as tf  # TensorFlow 1.x API; in 2.x use tf.io.TFRecordWriter

    def _bytes_feature(value):
        # Wrap a single bytes value in a tf.train.Feature
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def write_proto(filepath, item_id):  # other parameters omitted; item_id is an ASCII-encodable string
        # ...
        with tf.python_io.TFRecordWriter(filepath) as writer:
            example = tf.train.Example(features=tf.train.Features(feature={
                # write it as a bytes array, assuming the string is ASCII
                'item_id': _bytes_feature(bytes(item_id, encoding='ascii')),  # Python 3
                # ...
            }))
            writer.write(example.SerializeToString())
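
A plain-Python sketch of just the encoding step, independent of TensorFlow: `bytes(item_id, encoding='ascii')` yields one byte per character, and any non-ASCII character raises `UnicodeEncodeError`, which is why the answer assumes ASCII labels. The helper name `to_ascii_bytes` is hypothetical and only illustrates that conversion.

```python
def to_ascii_bytes(label: str) -> bytes:
    # Same conversion as in write_proto: str -> bytes, one byte per character.
    # Raises UnicodeEncodeError if the label contains a non-ASCII character.
    return bytes(label, encoding='ascii')

encoded = to_ascii_bytes('cat_01')
print(encoded)        # b'cat_01'
print(list(encoded))  # [99, 97, 116, 95, 48, 49]

try:
    to_ascii_bytes('caf\u00e9')  # 'café' contains a non-ASCII character
except UnicodeEncodeError as err:
    print('not ASCII:', err.reason)
```

If your labels may contain non-ASCII text, a different codec such as UTF-8 would be needed on both the write and read side; that is a substitution, not what the answer above does.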
    

    Then:

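    def parse_single_example(example_proto):  # other parameters omitted
        # TensorFlow 1.x API (in 2.x: tf.io.parse_single_example,
        # tf.io.FixedLenFeature, tf.io.decode_raw)
        features_dict = tf.parse_single_example(example_proto,
            features={'item_id': tf.FixedLenFeature([], tf.string),
            # ...
            })
        # decode as uint8 aka bytes: one uint8 value per byte of the string
        item_id = tf.decode_raw(features_dict['item_id'], tf.uint8)
        return item_id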
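
As a rough plain-Python analogy (not TensorFlow code), `tf.decode_raw(..., tf.uint8)` reinterprets the stored byte string as a vector of unsigned 8-bit integers, one per byte. The hypothetical helper `decode_raw_uint8` below mimics that with the standard `array` module.

```python
from array import array

def decode_raw_uint8(raw: bytes) -> array:
    # Plain-Python stand-in for tf.decode_raw(raw, tf.uint8):
    # every byte of the string becomes one unsigned 8-bit integer ('B').
    return array('B', raw)

for label in (b'cat', b'giraffe'):
    values = decode_raw_uint8(label)
    # The vector length equals the number of bytes in the label.
    print(label, values.tolist(), len(values))
```

Labels of different lengths give vectors of different lengths, so batching the decoded tensors would likely need padding (for example `padded_batch` on a `tf.data.Dataset`).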
    

    Then, when you fetch it in your session, convert it back to a string:

    item_id, *rest = session.run(your_tfrecords_iterator.get_next())
    print(str(item_id.flatten(), 'ascii'))  # Python 3: decodes the uint8 array's raw bytes as ASCII
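
The reverse step can likewise be sketched in plain Python, assuming the fetched value is a sequence of uint8 values: rebuild the bytes and decode them as ASCII, which is the effect of `str(item_id.flatten(), 'ascii')` on the NumPy array. `uint8_to_label` is a hypothetical name used only for this illustration.

```python
def uint8_to_label(values) -> str:
    # Rebuild a bytes object from integers in range(256) and decode it as ASCII.
    return bytes(values).decode('ascii')

original = 'cat_01'
# The uint8 values decode_raw would produce for this label (one per byte).
values = list(original.encode('ascii'))
restored = uint8_to_label(values)
print(restored)  # cat_01
```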
    

    I took the uint8 trick from this related SO answer. It works for me, but comments and improvements are welcome.
