Feeding image data in tensorflow for transfer learning

北荒 2020-12-02 23:31

I am trying to use TensorFlow for transfer learning. I downloaded the pre-trained Inception-v3 model from the tutorial. In the code, for prediction:

predictio         


        
相关标签:
2 Answers
  • 2020-12-03 00:09

    The following code should handle both cases (JPEG and PNG input).

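    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.gfile)
    from PIL import Image
    
    image_file = 'test.jpeg'
    # Assumes the retrained graph (which defines 'final_result:0') has already
    # been imported into the default graph, as in the tutorial.
    with tf.Session() as sess:
        name = image_file.lower()
        if name.endswith(('.jpg', '.jpeg')):
            # Feed the encoded JPEG bytes; the graph's DecodeJpeg op decodes them.
            with tf.gfile.GFile(image_file, 'rb') as f:
                image_data = f.read()
            prediction = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data})
        elif name.endswith('.png'):
            # Decode the PNG ourselves and feed the HxWx3 uint8 pixel array
            # straight to the output of the DecodeJpeg op.
            image_array = np.array(Image.open(image_file).convert('RGB'))
            prediction = sess.run('final_result:0', {'DecodeJpeg:0': image_array})
        else:
            raise ValueError('Unsupported image type: ' + image_file)
    
        prediction = prediction[0]  # drop the batch dimension
        print(prediction)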
    

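    Or, a shorter version that feeds the raw file bytes directly:

    import tensorflow as tf  # TensorFlow 1.x API

    image_file = 'test.png'  # or 'test.jpeg'
    with tf.gfile.GFile(image_file, 'rb') as f:
        image_data = f.read()

    # Assumption: the DecodeJpeg kernel in your TensorFlow version also accepts
    # PNG bytes; if it rejects them, use the PNG branch above instead.
    with tf.Session() as sess:
        predictions = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data})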
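Dispatching on the file extension trusts the file name, which breaks when an extension does not match the file's contents. A minimal stdlib-only sketch of an alternative is to inspect the leading signature bytes (FF D8 FF for JPEG, the 8-byte `\x89PNG\r\n\x1a\n` header for PNG). The helper names `sniff_image_format` and `feed_tensor_for` are hypothetical; the tensor names are the ones used in this thread, and whether they exist depends on the graph you loaded.

```python
# Hypothetical helpers: detect the image format from the file's leading
# "magic" bytes instead of trusting the file extension.
JPEG_MAGIC = b"\xff\xd8\xff"
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"


def sniff_image_format(data: bytes) -> str:
    """Return 'jpeg' or 'png' based on the file signature, else raise ValueError."""
    if data.startswith(JPEG_MAGIC):
        return "jpeg"
    if data.startswith(PNG_MAGIC):
        return "png"
    raise ValueError("unsupported image format (expected JPEG or PNG)")


def feed_tensor_for(data: bytes) -> str:
    """Pick which Inception-v3 tensor to feed (names as used in this thread).

    JPEG bytes can go straight into 'DecodeJpeg/contents:0'; PNG data has to be
    decoded to an HxWx3 uint8 array first and fed to 'DecodeJpeg:0'.
    """
    tensor_by_format = {"jpeg": "DecodeJpeg/contents:0", "png": "DecodeJpeg:0"}
    return tensor_by_format[sniff_image_format(data)]


if __name__ == "__main__":
    print(sniff_image_format(b"\xff\xd8\xff\xe0" + b"\x00" * 16))  # jpeg
    print(feed_tensor_for(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16))      # DecodeJpeg:0
```

Since only a few leading bytes are needed, you can read the file once and reuse the same bytes both for sniffing and, for JPEG, as the value fed to 'DecodeJpeg/contents:0'.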
    
  • 2020-12-03 00:21

    The shipped InceptionV3 graph used in classify_image.py only supports JPEG images out of the box. There are two ways you could use this graph with PNG images:

    1. Convert the PNG image to a height x width x 3 (channels) NumPy array, for example using PIL, then feed it to the 'DecodeJpeg:0' tensor:

      import numpy as np
      from PIL import Image
      # ...
      
      # convert("RGB") guarantees 3 channels: it drops alpha and expands
      # grayscale or palette images, which a [:, :, 0:3] slice cannot handle.
      image_array = np.array(Image.open("example.png").convert("RGB"))
      
      prediction = sess.run(softmax_tensor, {'DecodeJpeg:0': image_array})
      

      Perhaps confusingly, 'DecodeJpeg:0' is the output of the DecodeJpeg op, so by feeding this tensor you bypass the JPEG decoder and supply already-decoded pixel data directly.

    2. Add a tf.image.decode_png() op to the imported graph. Simply switching the name of the fed tensor from 'DecodeJpeg/contents:0' to 'DecodePng/contents:0' does not work because there is no 'DecodePng' op in the shipped graph. You can add such a node to the graph by using the input_map argument to tf.import_graph_def():

      import tensorflow as tf  # TensorFlow 1.x API
      
      png_data = tf.placeholder(tf.string, shape=[])
      decoded_png = tf.image.decode_png(png_data, channels=3)
      # ...
      
      graph_def = ...
      # import_graph_def returns a list matching return_elements; take the single tensor.
      softmax_tensor = tf.import_graph_def(
          graph_def,
          input_map={'DecodeJpeg:0': decoded_png},
          return_elements=['softmax:0'])[0]
      
      sess.run(softmax_tensor, {png_data: ...})
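Feeding 'DecodeJpeg:0' requires a height x width x 3 uint8 array, but decoded PNGs may be grayscale (2-D), carry an alpha channel (4 channels), or come from a non-PIL decoder in another dtype. A minimal NumPy-only sketch that normalizes such arrays before feeding; the helper name `to_rgb_uint8` is hypothetical, and it assumes non-uint8 input is already scaled to the 0..255 range.

```python
import numpy as np


def to_rgb_uint8(image_array: np.ndarray) -> np.ndarray:
    """Coerce a decoded image array to the HxWx3 uint8 layout fed to 'DecodeJpeg:0'.

    Handles grayscale (HxW or HxWx1), RGB (HxWx3) and RGBA (HxWx4) input.
    Alpha is simply dropped, not composited onto a background.
    """
    arr = np.asarray(image_array)
    if arr.ndim == 2:                    # grayscale HxW -> HxWx1
        arr = arr[:, :, np.newaxis]
    if arr.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {arr.shape}")
    channels = arr.shape[2]
    if channels == 1:                    # replicate gray into R, G and B
        arr = np.repeat(arr, 3, axis=2)
    elif channels == 4:                  # drop the alpha channel
        arr = arr[:, :, :3]
    elif channels != 3:
        raise ValueError(f"unsupported channel count: {channels}")
    if arr.dtype != np.uint8:
        # Assumption: non-uint8 input is already in the 0..255 range.
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr


if __name__ == "__main__":
    gray = np.zeros((4, 5), dtype=np.uint8)
    print(to_rgb_uint8(gray).shape)  # (4, 5, 3)
```

Dropping alpha rather than blending it matches what a plain channel slice does; if your PNGs rely on transparency, composite onto a background color first.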
      