Applying callbacks in a custom training loop in Tensorflow 2.0

伪装坚强ぢ · 2021-01-13 05:00

I'm writing a custom training loop using the code provided in the TensorFlow DCGAN implementation guide. I wanted to add callbacks in the training loop. In Keras I know we …

3 Answers
  •  谎友^ (OP)
     2021-01-13 05:36

    A custom training loop is just a normal Python loop, so you can use if statements to break the loop whenever some condition is met. For instance:

    if len(loss_history) > patience:
        # Oldest loss (reduced by delta) still beats every more recent loss
        # -> no relative improvement of more than delta in `patience` epochs.
        if loss_history.popleft()*(1 - delta) < min(loss_history):
            print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                  f'validation loss in the last {patience} epochs.')
            break
    

    If the loss has not improved by more than delta percent over the past patience epochs, the loop breaks. Here I'm using a collections.deque, which works as a rolling buffer: with maxlen set, it keeps only the most recent entries in memory, so it holds just the losses from the last few epochs.
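As a minimal standalone illustration of that rolling-buffer behaviour (plain Python, no TensorFlow required):

```python
from collections import deque

patience = 3
history = deque(maxlen=patience + 1)  # holds at most patience + 1 losses

# Append more losses than the buffer can hold.
for loss in [0.9, 0.8, 0.7, 0.6, 0.5]:
    history.append(loss)

# Only the last patience + 1 entries survive; older ones are dropped
# automatically, with no manual slicing or pop calls.
print(list(history))  # -> [0.8, 0.7, 0.6, 0.5]
```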

    Here's a full implementation, based on the example from the TensorFlow documentation (the model, get_grad, train and test objects come from that guide):

    from collections import deque
    
    import tensorflow as tf
    
    patience = 3
    delta = 0.001
    
    # Format string for the per-epoch log line printed below.
    verbose = 'Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:.3%} TAcc: {:.3%}'
    
    loss_history = deque(maxlen=patience + 1)
    
    for epoch in range(1, 25 + 1):
        train_loss = tf.metrics.Mean()
        train_acc = tf.metrics.CategoricalAccuracy()
        test_loss = tf.metrics.Mean()
        test_acc = tf.metrics.CategoricalAccuracy()
    
        for x, y in train:
            loss_value, grads = get_grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            train_loss.update_state(loss_value)
            train_acc.update_state(y, model(x, training=True))
    
        for x, y in test:
            loss_value, _ = get_grad(model, x, y)
            test_loss.update_state(loss_value)
            test_acc.update_state(y, model(x, training=False))
    
        print(verbose.format(epoch,
                             train_loss.result(),
                             test_loss.result(),
                             train_acc.result(),
                             test_acc.result()))
    
        loss_history.append(test_loss.result())
    
        if len(loss_history) > patience:
            # Stop when the oldest loss (reduced by delta) is still lower
            # than every loss recorded since, i.e. no real improvement.
            if loss_history.popleft()*(1 - delta) < min(loss_history):
                print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                      f'validation loss in the last {patience} epochs.')
                break

    Output:
    
    Epoch  1 Loss: 0.191 TLoss: 0.282 Acc: 68.920% TAcc: 89.200%
    Epoch  2 Loss: 0.157 TLoss: 0.297 Acc: 70.880% TAcc: 90.000%
    Epoch  3 Loss: 0.133 TLoss: 0.318 Acc: 71.560% TAcc: 90.800%
    Epoch  4 Loss: 0.117 TLoss: 0.299 Acc: 71.960% TAcc: 90.800%
    
    Early stopping. No improvement of more than 0.10000% in validation loss in the last 3 epochs.
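    To see the stopping rule in isolation, here is a self-contained sketch with made-up validation losses (no TensorFlow needed). The loss values are illustrative only; any sequence that stalls after an early low point will trigger the same check:

    ```python
    from collections import deque

    patience = 3
    delta = 0.001

    # Synthetic validation losses: improvement stalls after the first epoch.
    val_losses = [0.282, 0.297, 0.318, 0.299, 0.310, 0.305]

    loss_history = deque(maxlen=patience + 1)
    stopped_at = None

    for epoch, loss in enumerate(val_losses, start=1):
        loss_history.append(loss)
        if len(loss_history) > patience:
            # The oldest loss, shrunk by delta, still beats every recent loss
            # -> no improvement of more than delta in `patience` epochs.
            if loss_history.popleft() * (1 - delta) < min(loss_history):
                stopped_at = epoch
                break

    print(stopped_at)  # -> 4 (stops as soon as the window fills with no improvement)
    ```

    With these numbers the loop stops at epoch 4, matching the behaviour shown in the log above: the epoch-1 loss of 0.282 is never beaten in the following three epochs.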
    
