In Chainer, how to stop iteration early using chainer.training.Trainer?

Submitted by 泄露秘密 on 2019-12-08 03:58:11

Question


I am using the Chainer framework (deep learning). Suppose I want to stop iterating once the gap between the objective function values of two consecutive iterations is small: f - old_f < eps. But chainer.training.Trainer's stop_trigger is an (args.epoch, 'epoch') tuple. How can I trigger early stopping?


Answer 1:


I implemented an EarlyStoppingTrigger example based on @Seiya Tokui's answer, adapted to your situation.

from chainer import reporter
from chainer.training import util


class EarlyStoppingTrigger(object):

    """Early stopping trigger.

    It observes the value specified by `key` and fires only when the
    observed value satisfies `stop_condition`.
    The trigger can be passed to the `stop_trigger` option of Trainer to
    stop the training early.

    Args:
        max_epoch (int or float): Maximum epoch for the training. Even if the
            observed value never satisfies `stop_condition`, training
            finishes once it reaches `max_epoch` epochs.
        key (str): Key of the value to observe for `stop_condition`.
        stop_condition (callable): Called with the previous value and the
            current value to decide when to stop early. Default is `None`,
            in which case the internal `_stop_condition` method is used.
        eps (float): Threshold used by the internal `_stop_condition`.
        trigger: Trigger that decides the comparison interval between the
            previous value and the current value. This must be a tuple of
            the form ``(<int>, 'epoch')`` or ``(<int>, 'iteration')``, which
            is passed to :class:`~chainer.training.triggers.IntervalTrigger`.
    """

    def __init__(self, max_epoch, key, stop_condition=None, eps=0.01,
                 trigger=(1, 'epoch')):
        self.max_epoch = max_epoch
        self.eps = eps
        self._key = key
        self._current_value = None
        self._interval_trigger = util.get_trigger(trigger)
        self._init_summary()
        self.stop_condition = stop_condition or self._stop_condition

    def __call__(self, trainer):
        """Decides whether the training should stop at this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this
                trainer is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if the training should stop at this iteration.
        """
        epoch_detail = trainer.updater.epoch_detail
        if self.max_epoch <= epoch_detail:
            print('Reached max_epoch.')
            return True

        # Accumulate the observed value so that it can be averaged over
        # the comparison interval.
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        self._init_summary()
        if key not in stats:
            # The value was not reported during this interval.
            return False
        value = float(stats[key])  # copy to CPU

        if self._current_value is None:
            self._current_value = value
            return False
        else:
            if self.stop_condition(self._current_value, value):
                # print('Previous value {}, Current value {}'
                #       .format(self._current_value, value))
                print('Invoke EarlyStoppingTrigger...')
                self._current_value = value
                return True
            else:
                self._current_value = value
                return False

    def _init_summary(self):
        self._summary = reporter.DictSummary()

    def _stop_condition(self, current_value, new_value):
        # Stop when the decrease from the previous value is smaller than eps.
        return current_value - new_value < self.eps

Usage: pass it to the stop_trigger option of the Trainer:

early_stop = EarlyStoppingTrigger(args.epoch, key='validation/main/loss', eps=0.01)
trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)

See this gist for a complete working example.

[Note] If you use a customized stop_trigger, you also need to pass training_length explicitly to the ProgressBar extension.
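The default `_stop_condition` compares the previous interval's mean with the current one, so in the question's notation it stops when `old_f - f < eps`. The Chainer-free sketch below replays that rule over a made-up sequence of per-epoch means to show when it would fire. The helper names `default_stop_condition` and `stopping_epoch` and the loss values are hypothetical, and the sketch ignores the `max_epoch` cut-off.

```python
# Chainer-free replay of the default rule used by EarlyStoppingTrigger.
# Helper names and the loss curve are hypothetical, for illustration only.

def default_stop_condition(current_value, new_value, eps=0.01):
    # Same comparison as EarlyStoppingTrigger._stop_condition: fire when the
    # decrease from the previous value is smaller than eps. A value that goes
    # up gives a negative difference, so it fires as well.
    return current_value - new_value < eps


def stopping_epoch(epoch_means, eps=0.01):
    """Return the 1-based epoch at which the rule first fires, or None.

    `epoch_means` stands in for the per-interval means the trigger obtains
    from DictSummary.compute_mean() when trigger=(1, 'epoch').
    """
    previous = None
    for epoch, value in enumerate(epoch_means, start=1):
        if previous is not None and default_stop_condition(previous, value, eps):
            return epoch
        previous = value
    return None  # never fired: training would run until max_epoch


losses = [0.90, 0.60, 0.45, 0.40, 0.395, 0.39]  # made-up validation losses
print(stopping_epoch(losses, eps=0.01))  # prints 5
```

Because a rising loss also satisfies `current_value - new_value < eps`, this rule stops on the first epoch that fails to improve by at least `eps`; pass a custom `stop_condition` if you want, for example, `abs(current_value - new_value) < eps` instead.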




Answer 2:


You can pass a callable object to the stop_trigger option. The callable is called at each iteration with the Trainer object as its argument and should return a boolean value; when it returns True, training stops. To implement early stopping, write your own trigger function and pass it to the stop_trigger option of Trainer.

Other APIs that accept a trigger object also accept a callable; see the documentation of get_trigger for details.

Note: a tuple value for stop_trigger is shorthand for using chainer.training.triggers.IntervalTrigger as the callable.
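Following this description, a stop trigger can be as small as a closure. The sketch below assumes only what is stated above (the Trainer calls the trigger with itself and stops when it returns True) plus the `trainer.observation` dict and the `trainer.updater.iteration` counter. `make_early_stop_trigger`, its `max_iteration` parameter and the `fake_trainer` stub are hypothetical names, used so the logic can be exercised without Chainer installed.

```python
from types import SimpleNamespace


def make_early_stop_trigger(key, eps=0.01, max_iteration=10000):
    """Return a callable usable as Trainer's stop_trigger (hypothetical sketch).

    The callable receives the trainer and returns True to stop training.
    It reads trainer.observation[key] and trainer.updater.iteration.
    """
    previous = None

    def stop_trigger(trainer):
        nonlocal previous
        # Hard upper bound so training ends even if the value never settles.
        if trainer.updater.iteration >= max_iteration:
            return True
        # The key is only present on iterations where it was reported
        # (e.g. after an Evaluator extension ran); otherwise keep training.
        if key not in trainer.observation:
            return False
        # Assumed to be a scalar that float() can convert.
        value = float(trainer.observation[key])
        last, previous = previous, value
        if last is None:
            return False  # first reported value: nothing to compare with yet
        # Stop once the improvement old_f - f falls below eps.
        return last - value < eps

    return stop_trigger


def fake_trainer(iteration, observation):
    # Hypothetical stand-in exposing only the attributes the trigger reads.
    return SimpleNamespace(updater=SimpleNamespace(iteration=iteration),
                           observation=observation)


trigger = make_early_stop_trigger('validation/main/loss', eps=0.01)
decisions = [trigger(fake_trainer(i, {'validation/main/loss': v}))
             for i, v in enumerate([0.9, 0.6, 0.595])]
print(decisions)  # prints [False, False, True]
```

With a real trainer you would pass `stop_trigger=make_early_stop_trigger('validation/main/loss')` to `training.Trainer`. Unlike the class in Answer 1, this version compares values as they are reported instead of averaging them over an interval, so it suits a key that is reported occasionally and is already an average, such as an Evaluator's output, rather than the noisy per-iteration `main/loss`.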



Source: https://stackoverflow.com/questions/45891924/in-chainer-how-to-early-stop-iteration-using-chainer-training-trainer
