I'm writing a custom training loop using the code provided in the TensorFlow DCGAN implementation guide. I wanted to add callbacks in the training loop. In Keras I know we
A custom training loop is just a normal Python loop, so you can use if statements to break the loop whenever some condition is met. For instance:
```python
if len(loss_history) > patience:
    # Stop if no loss in the window beat the oldest one by more than a relative `delta`
    if loss_history.popleft() * (1 - delta) < min(loss_history):
        print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
              f'validation loss in the last {patience} epochs.')
        break
```
If the validation loss has not improved by more than a fraction `delta` (0.1% with `delta = 0.001`) over the past `patience` epochs, the loop breaks. Here I'm using a `collections.deque` as a rolling window. Created with `maxlen=patience + 1`, it only ever holds the oldest reference loss plus the losses of the `patience` epochs that follow it.
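You can try the rolling-window check on its own, without TensorFlow. The following is a minimal sketch that assumes the losses are plain Python floats. `should_stop` and `first_stop_epoch` are hypothetical helper names. The sketch reads the oldest loss with `history[0]` instead of `popleft()`, because a full `deque(maxlen=patience + 1)` drops its oldest entry on the next `append` anyway.

```python
from collections import deque


def should_stop(history, patience, delta):
    """True if none of the last `patience` losses improved on the oldest
    loss in the window by more than a relative `delta`."""
    if len(history) <= patience:
        return False  # not enough epochs seen yet
    oldest = history[0]
    recent = list(history)[1:]
    return oldest * (1 - delta) < min(recent)


def first_stop_epoch(val_losses, patience=3, delta=0.001):
    """Return the 1-based epoch at which early stopping triggers, or None."""
    history = deque(maxlen=patience + 1)  # rolling window of patience + 1 losses
    for epoch, loss in enumerate(val_losses, start=1):
        history.append(loss)
        if should_stop(history, patience, delta):
            return epoch
    return None


print(first_stop_epoch([0.282, 0.297, 0.318, 0.299]))  # → 4
print(first_stop_epoch([1.0, 0.9, 0.8, 0.7, 0.6]))     # → None (steadily improving)
```

Multiplying by `1 - delta` makes the threshold relative. With `delta = 0.001`, training stops once the best of the last `patience` losses is less than 0.1% below the reference loss.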
Here's a full implementation, based on the custom training example from the TensorFlow documentation:
```python
from collections import deque

import tensorflow as tf

# `model`, `optimizer`, `get_grad`, `train`, `test` and the `verbose` format
# string are defined elsewhere and are not shown here.

patience = 3
delta = 0.001  # minimum relative improvement (0.1%)
loss_history = deque(maxlen=patience + 1)  # rolling window of validation losses

for epoch in range(1, 25 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.CategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.CategoricalAccuracy()

    # Training pass
    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))

    # Validation pass (the gradients returned by get_grad are discarded)
    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))

    # Early stopping on the validation loss
    loss_history.append(float(test_loss.result()))
    if len(loss_history) > patience:
        if loss_history.popleft() * (1 - delta) < min(loss_history):
            print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                  f'validation loss in the last {patience} epochs.')
            break
```
```
Epoch 1 Loss: 0.191 TLoss: 0.282 Acc: 68.920% TAcc: 89.200%
Epoch 2 Loss: 0.157 TLoss: 0.297 Acc: 70.880% TAcc: 90.000%
Epoch 3 Loss: 0.133 TLoss: 0.318 Acc: 71.560% TAcc: 90.800%
Epoch 4 Loss: 0.117 TLoss: 0.299 Acc: 71.960% TAcc: 90.800%

Early stopping. No improvement of more than 0.10000% in validation loss in the last 3 epochs.
```
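The question asks about callbacks, so another option is to reuse Keras's own `tf.keras.callbacks.EarlyStopping` and call its hooks by hand instead of through `model.fit`. The following is a sketch that assumes TensorFlow 2.x. The helper name `run_with_keras_early_stopping`, the placeholder `Dense(1)` model and the hard-coded loss list are illustrative only. Calling callback hooks outside `fit` is not the documented usage, so check the behaviour on your TF version.

```python
import tensorflow as tf


def run_with_keras_early_stopping(val_losses, patience=3, min_delta=0.0):
    """Feed per-epoch validation losses to Keras's EarlyStopping callback by hand
    and return the (0-based) epoch at which it asks training to stop, or None."""
    # Placeholder model: EarlyStopping only needs an object to set `stop_training` on.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.stop_training = False

    # mode="min" because a lower validation loss is better.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", mode="min", patience=patience, min_delta=min_delta)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()

    for epoch, val_loss in enumerate(val_losses):
        # In a real loop this would be float(test_loss.result()) after each epoch.
        early_stopping.on_epoch_end(epoch, logs={"val_loss": val_loss})
        if model.stop_training:
            return epoch
    return None


# Validation losses taken from the example output above.
print(run_with_keras_early_stopping([0.282, 0.297, 0.318, 0.299]))
```

Unlike the deque version, Keras's `min_delta` is an absolute difference in loss. So `min_delta=0.001` means 0.001 in loss units, not 0.1%.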