Java - train loaded TensorFlow model

假如想象 submitted on 2019-12-08 13:50:51

Question


Does anyone know whether it is possible to continue training a model in Java after it has been loaded from TensorFlow Python? I came up with this snippet, but it did not work (yes, the output is the same as the input):

for (int i = 0; i < 10000; i++) {
    Tensor cost = b.session().runner().feed("input", input).feed("output", input).fetch("cost").run().get(0);
    // Prints the Tensor object itself (type and shape), not its numeric value.
    System.out.println(cost);
}

This is what is printed out 10000 times:

FLOAT tensor with shape []

After the loop finishes, the predictions are the same as they were before.

Moreover, if it is possible to continue training the loaded model, is it possible to update the saved model's weights and biases?


Answer 1:


You are feeding inputs and fetching the loss; that only evaluates the model and does not train it. To train, you need to feed batches of data and run the update ops (for example, the op returned by optimizer.minimize when the graph was built in Python).
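To make the distinction concrete, here is a minimal plain-Java sketch that does not use TensorFlow at all. The class TrainStepSketch, the one-weight LinearModel, the data, and the learning rate are all hypothetical and exist only for illustration: cost() plays the role of fetching "cost", and trainStep() plays the role of running an update op. With a real loaded graph, the analogous change would presumably be to run the training op as well, for instance via Session.Runner.addTarget(...) in the TensorFlow 1.x Java API, using whatever name that op was given in the Python graph.

```java
public class TrainStepSketch {

    /** Hypothetical one-parameter model: prediction = weight * x. */
    static final class LinearModel {
        double weight;

        LinearModel(double initialWeight) {
            this.weight = initialWeight;
        }

        /** Analogue of fetching "cost": computes mean squared error and changes nothing. */
        double cost(double[] xs, double[] ys) {
            double sum = 0.0;
            for (int i = 0; i < xs.length; i++) {
                double err = weight * xs[i] - ys[i];
                sum += err * err;
            }
            return sum / xs.length;
        }

        /** Analogue of running the update op: one gradient descent step on the weight. */
        void trainStep(double[] xs, double[] ys, double learningRate) {
            double grad = 0.0;
            for (int i = 0; i < xs.length; i++) {
                grad += 2.0 * (weight * xs[i] - ys[i]) * xs[i];
            }
            weight -= learningRate * grad / xs.length;
        }
    }

    public static void main(String[] args) {
        double[] xs = {1.0, 2.0, 3.0};
        double[] ys = {2.0, 4.0, 6.0}; // generated by y = 2x

        LinearModel model = new LinearModel(0.0);

        // Evaluating the cost repeatedly never changes the weight.
        for (int i = 0; i < 1000; i++) {
            model.cost(xs, ys);
        }
        System.out.println("weight after cost-only loop: " + model.weight);

        // Running the update step is what actually moves the weight.
        for (int i = 0; i < 1000; i++) {
            model.trainStep(xs, ys, 0.05);
        }
        System.out.println("weight after training loop: " + model.weight);
    }
}
```

The first loop mirrors the code in the question: the loss is computed on every iteration, but nothing ever writes to the parameters, so the predictions cannot change.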

It is possible to do this from Java, but the infrastructure in Python is better developed: it includes threads that prefetch input data into queues, monitoring for when the input is exhausted, saving summaries, and distributed training.



Source: https://stackoverflow.com/questions/43605690/java-train-loaded-tensorflow-model
