Using a model trained with the Python API as input to the LabelImage module in the Java API?

Submitted by 北战南征 on 2019-12-01 01:11:36

The model used by default in LabelImage.java is different from the model being retrained, so the names of the input and output nodes do not match. Note that TensorFlow models are graphs, and the arguments to feed() and fetch() are the names of nodes in the graph. So you need to know the names that are appropriate for your model.
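One thing that trips people up when copying names from Python code: Python usually refers to a tensor as `op_name:output_index` (for example `DecodeJpeg/contents:0`), whereas the Java snippet below passes bare operation names to feed()/fetch(). A small sketch of splitting the two apart; `TensorName` and `parse` are hypothetical names for illustration, not part of the TensorFlow API.

```java
// Hypothetical helper (not part of the TensorFlow API): splits a Python-style
// tensor name "op_name:output_index" into the operation name and output index.
final class TensorName {
    final String operation;
    final int outputIndex;

    private TensorName(String operation, int outputIndex) {
        this.operation = operation;
        this.outputIndex = outputIndex;
    }

    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        String suffix = colon < 0 ? "" : name.substring(colon + 1);
        // No ":" (or a non-numeric suffix) means "output 0 of the operation with this name".
        if (suffix.isEmpty() || !suffix.chars().allMatch(Character::isDigit)) {
            return new TensorName(name, 0);
        }
        return new TensorName(name.substring(0, colon), Integer.parseInt(suffix));
    }

    @Override
    public String toString() {
        return operation + ":" + outputIndex;
    }
}
```

For output 0 the bare operation name is enough; for other outputs, the 1.x Java `Session.Runner` also has feed/fetch overloads that take the output index explicitly.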

Looking at retrain.py, it seems that the retrained graph has a node that takes the raw contents of a JPEG file as input (the node DecodeJpeg/contents) and produces the per-label probabilities in the node final_result.
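"Raw contents" here means the file bytes exactly as they are on disk; no decoding happens in Java. A minimal sketch of reading them, with an optional sanity check that the file starts with the JPEG start-of-image marker (0xFF 0xD8 0xFF) so that, say, a PNG is not fed to DecodeJpeg by mistake. The check and the names `JpegInput`, `looksLikeJpeg` and `readJpegBytes` are my own additions, not something the answer or TensorFlow requires.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: loads the bytes that would go into Tensor.create(imageBytes).
final class JpegInput {
    private JpegInput() {}

    /** True if the data begins with the JPEG start-of-image marker 0xFF 0xD8 0xFF. */
    static boolean looksLikeJpeg(byte[] data) {
        return data.length >= 3
                && (data[0] & 0xFF) == 0xFF
                && (data[1] & 0xFF) == 0xD8
                && (data[2] & 0xFF) == 0xFF;
    }

    /** Reads the whole file unchanged; fails early if it does not look like a JPEG. */
    static byte[] readJpegBytes(Path path) throws IOException {
        byte[] data = Files.readAllBytes(path);
        if (!looksLikeJpeg(data)) {
            throw new IOException(path + " does not start with a JPEG start-of-image marker");
        }
        return data;
    }
}
```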

If that's the case, then you'd do something like the following in Java. You don't need the part of LabelImage.java that constructs a graph to normalize the image, since that seems to be part of the retrained model, so replace LabelImage.java:64 with something like:

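// graphDef and imageBytes are the byte[] contents of the retrained .pb file and of the JPEG file.
// Requires java.util.Arrays and org.tensorflow.{Graph, Session, Tensor}.
try (Tensor image = Tensor.create(imageBytes);
     Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
       // Note the change to the name of the node and the fact
       // that it is being provided the raw imageBytes as input
       Tensor result = s.runner()
           .feed("DecodeJpeg/contents", image)
           .fetch("final_result")
           .run()
           .get(0)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    // copyTo fills the array passed in. Reading that array (instead of indexing
    // copyTo's return value) also compiles in API versions where Tensor is generic
    // and the raw type erases copyTo's return type to Object.
    float[][] output = new float[1][nlabels];
    result.copyTo(output);
    float[] probabilities = output[0];
    // At this point nlabels = number of classes in your retrained model
    doSomethingWith(probabilities); // placeholder: replace with your own handling
  }
}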
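For the placeholder call at the end of the snippet, a common thing to do is pick the most likely class and look up its name. A sketch, assuming the labels file has one label per line in the same order as the model's outputs (as with the file retrain.py writes via its --output_labels option); the `Labels` class and its methods are hypothetical, not part of LabelImage.java.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

// Hypothetical helper for turning the probabilities array into a readable result.
final class Labels {
    private Labels() {}

    /** Reads one label per line (blank lines skipped), in the order of the model's outputs. */
    static List<String> readLabels(Path path) throws IOException {
        return Files.readAllLines(path, StandardCharsets.UTF_8).stream()
                .map(String::trim)
                .filter(line -> !line.isEmpty())
                .collect(Collectors.toList());
    }

    /** Index of the largest probability; the first index wins on ties. */
    static int argMax(float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("no probabilities");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Formats the best match as "label (NN.NN%)". */
    static String describeBest(float[] probabilities, List<String> labels) {
        if (labels.size() != probabilities.length) {
            throw new IllegalArgumentException(
                    labels.size() + " labels but " + probabilities.length + " probabilities");
        }
        int best = argMax(probabilities);
        return String.format(Locale.ROOT, "%s (%.2f%%)", labels.get(best), probabilities[best] * 100f);
    }
}
```

A length mismatch between the labels file and nlabels usually means the labels file belongs to a different training run, which is why describeBest refuses to guess.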

Hope that helps.

Regarding the "No operation" error, I was able to resolve it by using the input and output layer names "Mul" and "final_result", respectively. See:

https://github.com/tensorflow/tensorflow/issues/2883
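Feeding "Mul" instead of "DecodeJpeg/contents" means you supply the tensor that node would have produced, so the decoding, resizing and normalization the graph did up to that point have to happen in Java. A sketch of the normalization step only, assuming (not confirmed in this thread) a retrained Inception v3 graph whose "Mul" output is a [1, 299, 299, 3] float tensor computed from RGB pixels as (pixel - 128) / 128; check those numbers against your own model. Decoding and resizing to 299x299 are left to the caller, and `MulInput`/`toMulInput` are made-up names.

```java
// Hypothetical helper. Assumed preprocessing for a retrained Inception v3 graph
// fed at "Mul": 299x299 RGB, each channel mapped as (value - 128) / 128.
final class MulInput {
    static final int INPUT_SIZE = 299;   // assumption, verify for your model
    static final float IMAGE_MEAN = 128f; // assumption, verify for your model
    static final float IMAGE_STD = 128f;  // assumption, verify for your model

    private MulInput() {}

    /**
     * Converts packed 0xAARRGGBB pixels (row-major, already resized to width x height,
     * e.g. from BufferedImage.getRGB(0, 0, w, h, null, 0, w)) into a
     * [1][height][width][3] RGB array normalized as (value - mean) / std.
     */
    static float[][][][] toMulInput(int[] pixels, int width, int height, float mean, float std) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException(
                    "expected " + (width * height) + " pixels, got " + pixels.length);
        }
        float[][][][] out = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int p = pixels[y * width + x];
                out[0][y][x][0] = (((p >> 16) & 0xFF) - mean) / std; // red
                out[0][y][x][1] = (((p >> 8) & 0xFF) - mean) / std;  // green
                out[0][y][x][2] = ((p & 0xFF) - mean) / std;         // blue; alpha is ignored
            }
        }
        return out;
    }
}
```

The resulting array could then be wrapped with Tensor.create(...) and fed to "Mul" in place of the raw JPEG bytes, with "final_result" fetched as before.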
