How can I use values read from TFRecords as arguments to tf.reshape?

谎友^ 2021-01-05 00:41
def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  fe         


        
  •  离开以前
    2021-01-05 00:58

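    I have faced the same issue. According to the TensorFlow documentation, you will run into this if you try to use the shuffle_batch operation after reading the data: tf.train.shuffle_batch needs every tensor it enqueues to have a fully defined static shape, and a shape computed from values stored inside the record is only known when the graph runs.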

    As the example below shows, if you don't use shuffle_batch, you can reshape each record with dimensions that are stored in the record itself:

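        import tensorflow as tf  # TensorFlow 1.x queue-based input API

        def read_and_decode(filename_queue):
            reader = tf.TFRecordReader()
            _, serialized_example = reader.read(filename_queue)
            features = tf.parse_single_example(
                serialized_example,
                features={
                    'clip_height': tf.FixedLenFeature([], tf.int64),
                    'clip_width': tf.FixedLenFeature([], tf.int64),
                    'clip_raw': tf.FixedLenFeature([], tf.string),
                    'clip_label_raw': tf.FixedLenFeature([], tf.int64)
                })
            # The raw bytes decode to a flat 1-D tensor of float64 values.
            image = tf.decode_raw(features['clip_raw'], tf.float64)
            label = tf.cast(features['clip_label_raw'], tf.int32)
            # Height and width come from the record, so they are only
            # known when the graph runs.
            height = tf.cast(features['clip_height'], tf.int32)
            width = tf.cast(features['clip_width'], tf.int32)
            # Build the target shape as a tensor; -1 lets TF infer the last dim.
            im_shape = tf.stack([height, width, -1])
            new_image = tf.reshape(image, im_shape)
            return new_image, label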
    

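The reader-queue API used above (tf.TFRecordReader, tf.train.shuffle_batch) belongs to TensorFlow 1.x; in TensorFlow 2.x it is only reachable under tf.compat.v1, and tf.data is the usual replacement. As a minimal sketch, assuming TensorFlow 2.x, the same dynamic reshape can run inside Dataset.map. The make_example and parse_clip helpers, the temporary file and the two clip sizes are hypothetical, invented only so the example is self-contained. The sketch also makes the underlying problem visible: every element comes out with its own concrete shape, yet the static shape known while the pipeline is being built is unknown in every dimension, which is the condition that makes tf.train.shuffle_batch fail in the 1.x code.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x


def make_example(clip, label):
    """Hypothetical helper: serialize one clip with the answer's feature names."""
    height, width = clip.shape[:2]
    feature = {
        'clip_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'clip_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        'clip_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[clip.tobytes()])),
        'clip_label_raw': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


def parse_clip(serialized):
    """Hypothetical helper: parse one record and reshape it with its stored dimensions."""
    features = tf.io.parse_single_example(serialized, {
        'clip_height': tf.io.FixedLenFeature([], tf.int64),
        'clip_width': tf.io.FixedLenFeature([], tf.int64),
        'clip_raw': tf.io.FixedLenFeature([], tf.string),
        'clip_label_raw': tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.io.decode_raw(features['clip_raw'], tf.float64)  # flat 1-D tensor
    height = tf.cast(features['clip_height'], tf.int32)
    width = tf.cast(features['clip_width'], tf.int32)
    image = tf.reshape(image, tf.stack([height, width, -1]))  # shape known only at run time
    label = tf.cast(features['clip_label_raw'], tf.int32)
    return image, label


# Write two clips of different sizes to a temporary file (illustrative data only).
path = os.path.join(tempfile.mkdtemp(), 'clips.tfrecord')
clips = [np.zeros((2, 3, 1)), np.ones((4, 5, 1))]  # float64 by default
with tf.io.TFRecordWriter(path) as writer:
    for label, clip in enumerate(clips):
        writer.write(make_example(clip, label))

dataset = tf.data.TFRecordDataset(path).map(parse_clip)

# Each element gets its own concrete shape when it is actually read...
shapes = [tuple(image.shape) for image, _ in dataset]
print(shapes)
# ...but the static shape seen while building the pipeline has no known dimensions.
print(dataset.element_spec[0].shape)
```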
    But if you want to use shuffle_batch, you can't build the shape with tf.stack from the parsed values, because the reshaped tensor then has no static shape. You have to specify the dimensions statically, like this:

        import tensorflow as tf  # TensorFlow 1.x queue-based input API

        # filename_queue and batch_size are assumed to be defined elsewhere.
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'clip_height': tf.FixedLenFeature([], tf.int64),
                'clip_width': tf.FixedLenFeature([], tf.int64),
                'clip_raw': tf.FixedLenFeature([], tf.string),
                'clip_label_raw': tf.FixedLenFeature([], tf.int64)
            })
        image = tf.decode_raw(features['clip_raw'], tf.float64)
        label = tf.cast(features['clip_label_raw'], tf.int32)
        # clip_height and clip_width cannot drive the reshape here:
        # shuffle_batch needs a fully defined static shape, so the
        # dimensions are hard-coded. This assumes every clip holds
        # exactly 512 float64 values.
        image = tf.reshape(image, [1, 512, 1])
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=100)
    
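If the clips genuinely differ in size, hard-coding [1, 512, 1] is not an option. One alternative, a different technique from the fixed reshape above, is to shuffle with Dataset.shuffle and batch with Dataset.padded_batch, which zero-pads each element to the largest shape in its batch. This is a minimal sketch, assuming TensorFlow 2.x; the make_example and parse_clip helpers, the three clip sizes and batch_size = 3 are hypothetical values chosen for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x


def make_example(clip, label):
    """Hypothetical helper: serialize one clip with the answer's feature names."""
    height, width = clip.shape[:2]
    feature = {
        'clip_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'clip_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        'clip_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[clip.tobytes()])),
        'clip_label_raw': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


def parse_clip(serialized):
    """Hypothetical helper: reshape each clip with its own stored height and width."""
    features = tf.io.parse_single_example(serialized, {
        'clip_height': tf.io.FixedLenFeature([], tf.int64),
        'clip_width': tf.io.FixedLenFeature([], tf.int64),
        'clip_raw': tf.io.FixedLenFeature([], tf.string),
        'clip_label_raw': tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.io.decode_raw(features['clip_raw'], tf.float64)
    height = tf.cast(features['clip_height'], tf.int32)
    width = tf.cast(features['clip_width'], tf.int32)
    image = tf.reshape(image, tf.stack([height, width, -1]))
    label = tf.cast(features['clip_label_raw'], tf.int32)
    return image, label


# Three clips of different sizes, each filled with a distinct constant (illustrative data).
path = os.path.join(tempfile.mkdtemp(), 'clips.tfrecord')
clips = [np.full((2, 3, 1), 1.0), np.full((4, 5, 1), 2.0), np.full((3, 2, 1), 3.0)]
with tf.io.TFRecordWriter(path) as writer:
    for label, clip in enumerate(clips):
        writer.write(make_example(clip, label))

batch_size = 3  # hypothetical; large enough to hold all three clips in one batch
dataset = (tf.data.TFRecordDataset(path)
           .map(parse_clip)
           .shuffle(buffer_size=100)  # takes over the shuffling done by shuffle_batch
           # Zero-pad every image in a batch to the largest height and width in it.
           .padded_batch(batch_size, padded_shapes=([None, None, None], [])))

images, labels = next(iter(dataset))
print(images.shape, labels.numpy())
```

Padding changes the data, so a model consuming these batches may need a mask, or the stored clip_height and clip_width, to ignore the padded region. In the 1.x queue API, tf.train.batch has a comparable dynamic_pad=True option, while tf.train.shuffle_batch has no such argument.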

    @mrry please correct me if I am wrong.
