Using a model trained with the Python API as input to the LabelImage example in the Java API?

萌比男神i 2021-01-07 09:56

I have a problem with the Java TensorFlow API. I ran the training using the Python TensorFlow API, which generated the files output_graph.pb and output_labels.txt. Now for some

2 Answers
  • 2021-01-07 10:27

    The model used by default in LabelImage.java is different from the model that is being retrained, so the names of the input and output nodes do not match. Note that TensorFlow models are graphs, and the arguments to feed() and fetch() are names of nodes in the graph. So you need to know the names appropriate for your model.
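    Because feed() and fetch() look nodes up by name, one way to catch this kind of mismatch early is to compare the names a model is expected to contain against the names actually present in the graph. The sketch below is plain Java with no TensorFlow dependency; the ModelNodes class, its missingFrom method and the sample name set are hypothetical, and the "input"/"output" pair for the stock LabelImage.java model is an assumption you should check against your copy of the example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical helper: the feed/fetch node names a particular model is expected to contain.
final class ModelNodes {
    // Assumed names used by the stock LabelImage.java example; check your copy.
    static final ModelNodes LABEL_IMAGE_DEFAULT = new ModelNodes("input", "output");
    // Names suggested in this thread for a graph produced by retrain.py.
    static final ModelNodes RETRAINED_JPEG = new ModelNodes("DecodeJpeg/contents", "final_result");

    final String input;
    final String output;

    ModelNodes(String input, String output) {
        this.input = input;
        this.output = output;
    }

    // Returns the expected node names that are not present in the given set of graph node names.
    List<String> missingFrom(Set<String> graphNodeNames) {
        List<String> missing = new ArrayList<>();
        for (String name : List.of(input, output)) {
            if (!graphNodeNames.contains(name)) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // Hypothetical subset of node names, standing in for names listed from output_graph.pb.
        Set<String> names = Set.of("DecodeJpeg/contents", "DecodeJpeg", "Mul", "final_result");
        // prints "LabelImage defaults missing: [input, output]"
        System.out.println("LabelImage defaults missing: " + LABEL_IMAGE_DEFAULT.missingFrom(names));
        // prints "retrained names missing: []"
        System.out.println("retrained names missing: " + RETRAINED_JPEG.missingFrom(names));
    }
}
```

    In a real program the set of names would come from the imported graph rather than being hard-coded; the point is only that a missing name explains a "No operation" style failure before the session runs.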

    Looking at retrain.py, it seems that the retrained graph has a node that takes the raw contents of a JPEG file as input (the node DecodeJpeg/contents) and a node that produces the per-label probabilities (the node final_result).

    If that's the case, then you'd do something like the following in Java. You don't need the part that constructs a graph to normalize the image, since that appears to be part of the retrained model, so replace LabelImage.java:64 with something like:

    // Uses java.util.Arrays and org.tensorflow.Graph, Session and Tensor.
    // graphDef: bytes of output_graph.pb; imageBytes: raw contents of the JPEG file.
    try (Tensor image = Tensor.create(imageBytes);
         Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      // Note the change to the names of the nodes and the fact
      // that the raw imageBytes are fed directly as input.
      try (Session s = new Session(g);
           Tensor result =
               s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        float[] probabilities = result.copyTo(new float[1][nlabels])[0];
        // At this point nlabels = number of classes in your retrained model.
        // doSomethingWith is a placeholder for your own code (e.g. pick the most likely label).
        doSomethingWith(probabilities);
      }
    }
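    What to do with the probabilities is left open above. One possible approach, sketched under the assumption that output_labels.txt lists one label per line in the same order as the entries of final_result, is to take the index with the highest probability and look up its label. TopLabel, argmax and describe are hypothetical names, and the labels and probabilities in main are made-up sample values.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical stand-in for the placeholder above: pairs probabilities with labels.
final class TopLabel {
    // Index of the largest probability (argmax).
    static int argmax(float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty probabilities");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    // Assumes labels.get(i) names the class scored by probabilities[i].
    static String describe(float[] probabilities, List<String> labels) {
        if (labels.size() != probabilities.length) {
            throw new IllegalArgumentException(
                "Got " + labels.size() + " labels but " + probabilities.length + " probabilities");
        }
        int best = argmax(probabilities);
        return String.format("%s (%.2f%% likely)", labels.get(best), probabilities[best] * 100f);
    }

    public static void main(String[] args) throws IOException {
        // Write a throwaway labels file so the sketch runs on its own.
        Path labelsFile = Files.createTempFile("output_labels", ".txt");
        Files.write(labelsFile, List.of("daisy", "roses", "tulips"), StandardCharsets.UTF_8);
        List<String> labels = Files.readAllLines(labelsFile, StandardCharsets.UTF_8);
        float[] probabilities = {0.1f, 0.7f, 0.2f}; // made-up values for illustration
        System.out.println(describe(probabilities, labels));
        Files.delete(labelsFile);
    }
}
```

    The size check is a cheap guard: if the number of labels differs from nlabels, the labels file probably does not belong to the graph being loaded.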
    

    Hope that helps.

  • 2021-01-07 10:31

    Regarding the "No operation" error, I was able to resolve that by using input and output layer names "Mul" and "final_result", respectively. See:

    https://github.com/tensorflow/tensorflow/issues/2883
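    Feeding "Mul" instead of DecodeJpeg/contents presumably bypasses the decode, resize and normalization steps that precede it in the Inception v3 graph that retrain.py builds on, so the caller has to supply already-normalized pixels. The sketch below converts packed ARGB pixels into a [1][height][width][3] float array using (value - mean) / std. The 299x299 input size and mean = std = 128 are assumptions (values commonly used with retrained Inception v3 graphs), not something stated in this thread; resizing is not shown; and MulInput/normalize are hypothetical names.

```java
// Hypothetical preprocessing for feeding a node such as "Mul" (assumed: RGB, mean 128, std 128,
// image already resized to the model's input size, e.g. 299x299).
final class MulInput {
    // argbPixels: row-major packed 0xAARRGGBB values, e.g. as returned by BufferedImage.getRGB.
    static float[][][][] normalize(int[] argbPixels, int width, int height, float mean, float std) {
        if (argbPixels.length != width * height) {
            throw new IllegalArgumentException("pixel count does not match width * height");
        }
        float[][][][] out = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int p = argbPixels[y * width + x];
                out[0][y][x][0] = (((p >> 16) & 0xFF) - mean) / std; // red
                out[0][y][x][1] = (((p >> 8) & 0xFF) - mean) / std;  // green
                out[0][y][x][2] = ((p & 0xFF) - mean) / std;         // blue
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A 2x1 "image": one pure white pixel, one pure black pixel.
        int[] pixels = {0xFFFFFFFF, 0xFF000000};
        float[][][][] t = normalize(pixels, 2, 1, 128f, 128f);
        System.out.println(t[0][0][0][0] + " " + t[0][0][1][0]); // prints 0.9921875 -1.0
    }
}
```

    The resulting array would then be wrapped in a float tensor (for example with Tensor.create in the 1.x Java API) and fed to "Mul", with "final_result" fetched as before.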
