How is train_on_batch() different from fit()? In what cases should we use train_on_batch()?
train_on_batch() gives you finer control over training, for example over the state of a stateful LSTM, where you need to decide exactly when model.reset_states() is called. With multi-series data you may want to reset the state after each series; train_on_batch() lets you do that, whereas .fit() would train over all of the series without ever resetting the state between them. Neither approach is right or wrong; it depends on your data and on how you want the network to behave (see the sketch below).