Question
I am using the Chainer deep learning framework. Suppose I want to stop iterating once the gap between the objective function values of two consecutive iterations is small: f - old_f < eps. However, the stop_trigger of chainer.training.Trainer is an (args.epoch, 'epoch') tuple. How can I trigger early stopping?
Answer 1:
Following @Seiya Tokui's answer, I implemented an example EarlyStoppingTrigger for your situation.
```python
from chainer import reporter
from chainer.training import util


class EarlyStoppingTrigger(object):
    """Early stopping trigger.

    It observes the value specified by `key` and fires only when the
    observed value satisfies `stop_condition`. The trigger can be passed
    to the `stop_trigger` option of Trainer to stop training early.

    Args:
        max_epoch (int or float): Maximum number of epochs. Even if
            `stop_condition` is never satisfied, training finishes once it
            reaches `max_epoch` epochs.
        key (str): Key of the observed value checked by `stop_condition`.
        stop_condition (callable): Takes the previous value and the current
            value and decides whether to stop. Defaults to `None`, in which
            case the internal `_stop_condition` method is used.
        eps (float): Threshold used by the internal `_stop_condition`.
        trigger: Trigger that decides the interval at which the previous
            value and the current value are compared. This must be a tuple
            in the form of ``<int>, 'epoch'`` or ``<int>, 'iteration'``,
            which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.
    """

    def __init__(self, max_epoch, key, stop_condition=None, eps=0.01,
                 trigger=(1, 'epoch')):
        self.max_epoch = max_epoch
        self.eps = eps
        self._key = key
        self._current_value = None
        self._interval_trigger = util.get_trigger(trigger)
        self._init_summary()
        self.stop_condition = stop_condition or self._stop_condition

    def __call__(self, trainer):
        """Decides whether training should stop at this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this
                trainer is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if training should stop at this iteration.
        """
        epoch_detail = trainer.updater.epoch_detail
        if self.max_epoch <= epoch_detail:
            print('Reached max_epoch.')
            return True

        # Accumulate the observed value over the current interval.
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        self._init_summary()
        if key not in stats:
            # Nothing was reported under `key` during this interval.
            return False
        value = float(stats[key])  # copy to CPU

        if self._current_value is None:
            # First interval: only record a baseline.
            self._current_value = value
            return False

        should_stop = self.stop_condition(self._current_value, value)
        if should_stop:
            # print('Previous value {}, Current value {}'
            #       .format(self._current_value, value))
            print('Invoke EarlyStoppingTrigger...')
        self._current_value = value
        return should_stop

    def _init_summary(self):
        self._summary = reporter.DictSummary()

    def _stop_condition(self, current_value, new_value):
        # Stop when the value improved (decreased) by less than eps.
        return current_value - new_value < self.eps
```
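To make the default rule concrete, here is a Chainer-free sketch that replays _stop_condition over a list of per-interval mean values. The helper name first_stop_interval and the loss values are made up for illustration, and the max_epoch cap is not modeled. As in the class, the first interval only records a baseline, and the trigger fires as soon as the value improves by less than eps, which also covers the case where the value gets worse (the difference is then negative).

```python
def first_stop_interval(interval_means, eps=0.01):
    """Hypothetical helper: return the index of the interval at which the
    default rule would fire, or None if it never fires.

    Mirrors EarlyStoppingTrigger._stop_condition: the first interval only
    records a baseline; each later interval fires when
    (previous - current) < eps.
    """
    previous = None
    for i, value in enumerate(interval_means):
        if previous is not None and previous - value < eps:
            return i
        previous = value
    return None


# Made-up per-epoch mean validation losses, for illustration only.
losses = [0.90, 0.60, 0.45, 0.40, 0.395, 0.39]
print(first_stop_interval(losses, eps=0.01))  # 4: 0.40 -> 0.395 improves by only 0.005
```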
Usage: pass it to the stop_trigger option of the Trainer (this assumes an Evaluator extension reports validation/main/loss):

```python
early_stop = EarlyStoppingTrigger(args.epoch, key='validation/main/loss', eps=0.01)
trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)
```
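The stop_condition argument accepts any callable that takes the previous and current values and returns a bool. As a sketch, here is a hypothetical relative-improvement condition (the name and the rel_eps parameter are illustrative, not part of Chainer). Because the class calls stop_condition(previous, current) positionally, extra settings can be bound with functools.partial. Note that eps is only used by the internal _stop_condition, so a custom condition has to carry its own threshold.

```python
from functools import partial


def relative_improvement_condition(previous_value, current_value, rel_eps=0.001):
    """Hypothetical stop_condition: fire when the value improved (decreased)
    by less than rel_eps relative to the previous value."""
    improvement = previous_value - current_value
    return improvement < rel_eps * abs(previous_value)


# With Chainer and the class above available (not executed here):
# early_stop = EarlyStoppingTrigger(
#     args.epoch, key='validation/main/loss',
#     stop_condition=partial(relative_improvement_condition, rel_eps=0.001))

condition = partial(relative_improvement_condition, rel_eps=0.01)
print(condition(2.0, 1.5))   # improved by 25%: keep training -> False
print(condition(2.0, 1.99))  # improved by 0.5%: stop -> True
```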
See this gist for a complete working example.
[Note] If you use a customized stop_trigger, you also need to pass training_length to the ProgressBar extension explicitly, e.g. extensions.ProgressBar(training_length=(args.epoch, 'epoch')).
Answer 2:
You can pass a callable object to the stop_trigger option. The trainer calls it at each iteration, passing the Trainer object, and it should return a boolean value; when it returns True, training stops. To implement early stopping, write your own trigger function and pass it to the stop_trigger option of Trainer.
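As a sketch of that contract, here is a closure-based stop_trigger that applies the question's rule directly: it stops once the reported value changes by less than eps between two consecutive calls, with an iteration cap as a safety net. To keep it runnable without Chainer, the trainer is a hypothetical stand-in built from SimpleNamespace with only the two attributes the trigger reads (updater.iteration and observation); make_gap_trigger and the loss values are illustrative, not part of Chainer. With a real Trainer you would pass the returned function as stop_trigger; note that per-iteration losses are noisy, which is why Answer 1 averages over an interval before comparing.

```python
from types import SimpleNamespace


def make_gap_trigger(key, eps, max_iteration):
    """Hypothetical helper: build a stop_trigger that fires when the value
    reported under `key` changes by less than `eps` between two consecutive
    calls, or when `max_iteration` is reached."""
    old_value = None

    def trigger(trainer):
        nonlocal old_value
        if trainer.updater.iteration >= max_iteration:
            return True  # safety cap so training always ends
        value = trainer.observation.get(key)
        if value is None:
            return False  # nothing reported yet
        # A real Trainer may report a chainer.Variable here; it may need
        # converting to a plain number first.
        value = float(value)
        previous, old_value = old_value, value
        return previous is not None and abs(value - previous) < eps

    return trigger


# Stand-in trainer exposing only what the trigger reads.
fake = SimpleNamespace(updater=SimpleNamespace(iteration=0), observation={})
losses = [1.0, 0.7, 0.55, 0.5, 0.499, 0.498]  # made-up values
stop = make_gap_trigger('main/loss', eps=0.005, max_iteration=100)

# Roughly mimics the Trainer loop: the stop trigger is checked before each update.
while not stop(fake):
    fake.observation = {'main/loss': losses[fake.updater.iteration]}
    fake.updater.iteration += 1

print(fake.updater.iteration)  # 5: 0.5 -> 0.499 differs by only 0.001
```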
Other APIs that accept a trigger object also accept a callable; see the documentation of get_trigger for details.
Note: a tuple value for stop_trigger is shorthand for using chainer.training.triggers.IntervalTrigger as the callable.
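The shorthand can be pictured with a toy version of that dispatch. ToyIntervalTrigger and to_trigger below are hypothetical, simplified stand-ins written for illustration, not Chainer's code: the real IntervalTrigger also supports the 'epoch' unit, and get_trigger handles cases this sketch omits. The point is only that a callable is used as-is while a (period, unit) tuple is turned into an interval trigger.

```python
from types import SimpleNamespace


class ToyIntervalTrigger:
    """Simplified stand-in for chainer.training.triggers.IntervalTrigger:
    fires every `period` iterations (the 'epoch' unit is not modeled)."""

    def __init__(self, period, unit):
        if unit != 'iteration':
            raise ValueError("this toy only models the 'iteration' unit")
        self.period = period

    def __call__(self, trainer):
        iteration = trainer.updater.iteration
        return iteration > 0 and iteration % self.period == 0


def to_trigger(trigger):
    """Toy version of the get_trigger dispatch: a callable is used as-is,
    and a (period, unit) tuple becomes an interval trigger."""
    if callable(trigger):
        return trigger
    return ToyIntervalTrigger(*trigger)


trainer = SimpleNamespace(updater=SimpleNamespace(iteration=0))
every_three = to_trigger((3, 'iteration'))
fired = []
for it in range(1, 8):
    trainer.updater.iteration = it
    if every_three(trainer):
        fired.append(it)
print(fired)  # [3, 6]
```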
Source: https://stackoverflow.com/questions/45891924/in-chainer-how-to-early-stop-iteration-using-chainer-training-trainer