Using training made with python API as input to LabelImage module in java API?

Posted by ♀尐吖头ヾ on 2019-12-03 22:44:43

Question


I have a problem with the TensorFlow Java API. I ran the training with the TensorFlow Python API, which generated the files output_graph.pb and output_labels.txt. I now want to use those files as input to the LabelImage module in the TensorFlow Java API. I expected this to work, since that module takes exactly one .pb file and one .txt file. However, when I run the module, I get this error:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)

I would be very grateful if you could help me find where the problem is. I would also like to know whether there is a way to run the training from the TensorFlow Java API, because that would make things easier.

To be more precise:

I am not using any self-written code, at least for the relevant steps. All I did was run the training with this module, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py, feeding it the directory that contains the images, divided into subdirectories according to their description. In particular, I think these are the lines that generate the outputs:

output_graph_def = graph_util.convert_variables_to_constants(
    sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
  f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
  f.write('\n'.join(image_lists.keys()) + '\n')

Then I pass the outputs (one some_graph.pb and one some_labels.txt) as input to this Java module: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java, replacing the default inputs. The error I get is the one reported above.


Answer 1:


The model used by default in LabelImage.java is different from the model that is being retrained, so the names of the input and output nodes do not match. Note that TensorFlow models are graphs, and the arguments to feed() and fetch() are names of nodes in the graph. So you need to know the names appropriate for your model.
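If you are not sure which node names your .pb file contains, one way to find out without any extra dependencies is to walk the serialized GraphDef yourself. The following is only a minimal sketch: the class GraphNodeLister is hypothetical (not part of TensorFlow), and it assumes the protobuf layout in which GraphDef stores its nodes in field 1 and each NodeDef stores its name and op in fields 1 and 2.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists "name (op)" for every node in a serialized GraphDef
// by reading the protobuf wire format directly.
// Assumption: GraphDef.node is field 1; NodeDef.name and NodeDef.op are fields 1 and 2.
final class GraphNodeLister {
  private final byte[] buf;
  private int pos;
  private final int end;

  private GraphNodeLister(byte[] buf, int pos, int end) {
    this.buf = buf;
    this.pos = pos;
    this.end = end;
  }

  static List<String> listNodes(byte[] graphDef) {
    List<String> nodes = new ArrayList<>();
    GraphNodeLister r = new GraphNodeLister(graphDef, 0, graphDef.length);
    while (r.pos < r.end) {
      long tag = r.varint();
      int field = (int) (tag >>> 3);
      int wire = (int) (tag & 7);
      if (field == 1 && wire == 2) {
        // Length-delimited NodeDef: parse it with a reader limited to its bytes.
        int len = (int) r.varint();
        nodes.add(new GraphNodeLister(graphDef, r.pos, r.pos + len).describeNode());
        r.pos += len;
      } else {
        r.skip(wire);
      }
    }
    return nodes;
  }

  private String describeNode() {
    String name = "";
    String op = "";
    while (pos < end) {
      long tag = varint();
      int field = (int) (tag >>> 3);
      int wire = (int) (tag & 7);
      if (wire == 2 && field == 1) {
        name = string();
      } else if (wire == 2 && field == 2) {
        op = string();
      } else {
        skip(wire); // inputs, device, attrs (including large constants)
      }
    }
    return name + " (" + op + ")";
  }

  // Reads a base-128 varint starting at the current position.
  private long varint() {
    long value = 0;
    int shift = 0;
    while (true) {
      byte b = buf[pos++];
      value |= (long) (b & 0x7F) << shift;
      if ((b & 0x80) == 0) {
        return value;
      }
      shift += 7;
    }
  }

  private String string() {
    int len = (int) varint();
    String s = new String(buf, pos, len, StandardCharsets.UTF_8);
    pos += len;
    return s;
  }

  // Skips one field value of the given wire type.
  private void skip(int wire) {
    switch (wire) {
      case 0: varint(); break;
      case 1: pos += 8; break;
      case 2: { int len = (int) varint(); pos += len; break; }
      case 5: pos += 4; break;
      default: throw new IllegalArgumentException("Unsupported wire type " + wire);
    }
  }

  public static void main(String[] args) throws IOException {
    for (String node : listNodes(Files.readAllBytes(Paths.get(args[0])))) {
      System.out.println(node);
    }
  }
}
```

Running it with the path to output_graph.pb as the only argument prints one line per node, which should make it easy to spot the names to pass to feed() and fetch().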

Looking at retrain.py, it seems that it has a node that takes the raw contents of a JPEG file as input (the node DecodeJpeg/contents) and produces the set of labels in the node final_result.

If that's the case, you don't need the part that constructs a graph to normalize the image, since that seems to be part of the retrained model. You would replace LabelImage.java:64 with something like the following in Java:

try (Tensor image = Tensor.create(imageBytes);
     Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
    // Note the change to the name of the node and the fact
    // that it is being provided the raw imageBytes as input
    Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float[] probabilities = result.copyTo(new float[1][nlabels])[0];
    // At this point nlabels = number of classes in your retrained model
    DoSomethingWith(probabilities);
  }
}
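As for what to do with the probabilities, a common next step is to pair them with the lines of output_labels.txt and print the best matches. Here is a minimal sketch of that; TopLabels and topK are hypothetical names, and it assumes that the i-th line of the labels file corresponds to index i of final_result.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

// Hypothetical helper: pairs each probability with its label and returns the k best.
// Assumption: labels.get(i) describes probabilities[i].
final class TopLabels {
  static List<String> topK(float[] probabilities, List<String> labels, int k) {
    if (probabilities.length != labels.size()) {
      throw new IllegalArgumentException(
          "Got " + probabilities.length + " probabilities but " + labels.size() + " labels");
    }
    // Sort the class indices by descending probability.
    List<Integer> indices = new ArrayList<>();
    for (int i = 0; i < probabilities.length; i++) {
      indices.add(i);
    }
    indices.sort((a, b) -> Float.compare(probabilities[b], probabilities[a]));

    List<String> best = new ArrayList<>();
    for (int i = 0; i < Math.min(k, indices.size()); i++) {
      int idx = indices.get(i);
      best.add(String.format(Locale.ROOT, "%s (%.2f%%)", labels.get(idx), probabilities[idx] * 100f));
    }
    return best;
  }
}
```

The labels can be loaded with Files.readAllLines(Paths.get(labelsPath), StandardCharsets.UTF_8), where labelsPath is whatever path you gave to output_labels.txt, and DoSomethingWith(probabilities) could then simply print the result of TopLabels.topK(probabilities, labels, 3).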

Hope that helps.




Answer 2:


Regarding the "No operation" error, I was able to resolve that by using input and output layer names "Mul" and "final_result", respectively. See:

https://github.com/tensorflow/tensorflow/issues/2883



Source: https://stackoverflow.com/questions/43628829/using-training-made-with-python-api-as-input-to-labelimage-module-in-java-api
