What does train_on_batch() do in keras model?

杀马特。学长 韩版系。学妹 submitted on 2020-01-24 06:42:06

Question


I saw a sample of code (too big to paste here) where the author used model.train_on_batch(in, out) instead of model.fit(in, out). The official documentation of Keras says:

Single gradient update over one batch of samples.

But I don't get it. Is it the same as fit(), but instead of doing many feed-forward and backprop steps, it does it once? Or am I wrong?


Answer 1:


Yes. train_on_batch trains on a single batch, once: one forward pass, one backward pass, and one weight update.

fit, by contrast, iterates over many batches for one or more epochs, and each batch causes one weight update.

The usual reason to call train_on_batch directly is to run your own logic between batches.
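To make the difference concrete, here is a minimal pure-Python sketch. `ToyModel` is a hypothetical stand-in, not Keras: a one-weight linear model trained with plain SGD whose `train_on_batch` applies exactly one gradient step and whose `fit` simply loops over batches and epochs calling it. The variables `single`, `full`, and `manual` are also just illustrative names. The point is the update count and where custom code can go, not the model itself.

```python
class ToyModel:
    """Hypothetical one-weight linear model y = w * x trained with plain SGD.

    Not Keras itself: it only mimics how many weight updates
    train_on_batch and fit perform.
    """

    def __init__(self, w=0.0, lr=0.01):
        self.w = w
        self.lr = lr
        self.updates = 0  # number of gradient steps applied so far

    def train_on_batch(self, xs, ys):
        # Mean squared error on this one batch, and its gradient w.r.t. w.
        n = len(xs)
        errs = [self.w * x - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errs) / n
        grad = sum(2 * e * x for e, x in zip(errs, xs)) / n
        # Exactly one gradient update; the returned loss is computed before it.
        self.w -= self.lr * grad
        self.updates += 1
        return loss

    def fit(self, xs, ys, batch_size=32, epochs=1):
        # Many updates: every batch of every epoch gets one train_on_batch step.
        for _ in range(epochs):
            for start in range(0, len(xs), batch_size):
                self.train_on_batch(xs[start:start + batch_size],
                                    ys[start:start + batch_size])


xs = [float(i) for i in range(10)]
ys = [2.0 * x for x in xs]                 # the ideal weight is 2.0

single = ToyModel()
single.train_on_batch(xs[:4], ys[:4])      # one batch, one update

full = ToyModel()
full.fit(xs, ys, batch_size=4, epochs=3)   # 3 batches x 3 epochs = 9 updates

# A hand-written loop: calling train_on_batch yourself lets you run custom
# code between batches (here: logging each loss) and between epochs
# (here: halving the learning rate).
manual = ToyModel()
losses = []
for _ in range(3):
    for start in range(0, len(xs), 4):
        losses.append(manual.train_on_batch(xs[start:start + 4],
                                            ys[start:start + 4]))
    manual.lr *= 0.5
```

The real Keras `fit` does more than this loop (for example, it shuffles the data by default and runs callbacks), but the one-update-per-batch structure is the same idea.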




Answer 2:


It is used when you want to inspect the model or make custom changes after training on each batch.

A more precise use case is training GANs. You have to update the discriminator, but while updating the combined GAN network you have to keep the discriminator untrainable. So you first train the discriminator, and then train the GAN (which updates only the generator) with the discriminator frozen. See this for more details: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
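The freeze-then-update order can be sketched in pure Python too. Everything below is a hypothetical toy, not a real GAN and not the Keras API: `ScalarNet` is a one-weight network with a Keras-style `trainable` flag, the "discriminator" is just a linear scorer fitted with MSE, and the helpers `apply_grads`, `disc_train_on_batch`, and `combined_train_on_batch` are made-up names. What it does show is the alternation: train the discriminator, freeze it, then update the stacked model so the gradient flows through the discriminator while only the generator's weight changes.

```python
class ScalarNet:
    """Hypothetical one-weight network x -> w * x with a Keras-style
    `trainable` flag. A toy stand-in, not a real generator/discriminator."""

    def __init__(self, w):
        self.w = w
        self.trainable = True


def apply_grads(pairs, lr):
    # Update only trainable networks, the way an optimizer skips frozen weights.
    for net, grad in pairs:
        if net.trainable:
            net.w -= lr * grad


def disc_train_on_batch(disc, xs, labels, lr=0.05):
    # One MSE update of the discriminator alone on (sample, label) pairs.
    n = len(xs)
    errs = [disc.w * x - t for x, t in zip(xs, labels)]
    grad = sum(2 * e * x for e, x in zip(errs, xs)) / n
    apply_grads([(disc, grad)], lr)


def combined_train_on_batch(gen, disc, zs, target, lr=0.05):
    # One MSE update of the stacked model disc(gen(z)). The chain rule gives
    # a gradient for both weights; only trainable networks actually change.
    n = len(zs)
    errs = [disc.w * gen.w * z - target for z in zs]
    grad_g = sum(2 * e * disc.w * z for e, z in zip(errs, zs)) / n
    grad_d = sum(2 * e * gen.w * z for e, z in zip(errs, zs)) / n
    apply_grads([(gen, grad_g), (disc, grad_d)], lr)


gen, disc = ScalarNet(0.5), ScalarNet(1.0)
zs = [1.0, 2.0]      # toy noise batch
real = [3.0, 4.0]    # toy "real" samples

for _ in range(5):
    # 1) Discriminator step: real samples -> 1, generated samples -> 0.
    disc.trainable = True
    fake = [gen.w * z for z in zs]
    disc_train_on_batch(disc, real + fake,
                        [1.0] * len(real) + [0.0] * len(fake))

    # 2) Generator step: freeze the discriminator, then train the stacked
    #    model to output 1 ("real") for generated samples.
    disc.trainable = False
    d_before = disc.w
    combined_train_on_batch(gen, disc, zs, target=1.0)
    assert disc.w == d_before  # the frozen discriminator did not move
```

In tf.keras, the common recipe gets the same effect by compiling the discriminator on its own first, then setting `discriminator.trainable = False` before compiling the stacked model, since changes to `trainable` take effect when a model is compiled; each batch then calls `train_on_batch` on the two compiled models in turn.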



Source: https://stackoverflow.com/questions/48550201/what-does-train-on-batch-do-in-keras-model
