Check early stopping metric in the beginning of the training (#542)
* Early stopping fix

* Update trainer.py

* Don't force validation sanity check

* fix tests

* update

* Added early_stopping check_metrics

* Updated docs

* Update docs

* Do not call early stopping when validation is disabled

Co-authored-by: William Falcon <waf2107@columbia.edu>
kuynzereb and williamFalcon committed Jan 23, 2020
1 parent 588ad83 commit 50881c0
Showing 6 changed files with 65 additions and 22 deletions.
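Taken together, the changes amount to the following user-facing behavior. A minimal sketch, assuming the import paths of this pytorch-lightning version (model setup omitted):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # New default: early_stop_callback=None builds a non-strict EarlyStopping
    # (monitor='val_loss', patience=3, strict=False, verbose=False). If
    # 'val_loss' is never logged, training proceeds as if early stopping
    # were disabled.
    trainer = Trainer(early_stop_callback=None)

    # Explicit opt-in: early_stop_callback=True builds a strict callback, so
    # a missing 'val_loss' now raises a RuntimeError instead of only warning.
    trainer = Trainer(early_stop_callback=True)

    # Explicit opt-out disables early stopping entirely.
    trainer = Trainer(early_stop_callback=False)

    # Or pass a configured instance to control strictness yourself.
    trainer = Trainer(early_stop_callback=EarlyStopping(monitor='val_loss',
                                                        patience=3, strict=True))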
45 changes: 30 additions & 15 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -71,21 +71,23 @@ class EarlyStopping(Callback):

     Stop training when a monitored quantity has stopped improving.

     Args:
-        monitor (str): quantity to be monitored.
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
         min_delta (float): minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
-            change of less than min_delta, will count as no
-            improvement.
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
         patience (int): number of epochs with no improvement
-            after which training will be stopped.
-        verbose (bool): verbosity mode.
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
         mode (str): one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
-            from the name of the monitored quantity.
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.

     Example::
@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
     """

     def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
         super(EarlyStopping, self).__init__()

         self.monitor = monitor
         self.patience = patience
         self.verbose = verbose
+        self.strict = strict
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0

         if mode not in ['auto', 'min', 'max']:
-            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            if self.verbose > 0:
+                logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'

         if mode == 'min':
Expand All @@ -128,23 +132,34 @@ def __init__(self, monitor='val_loss',

self.on_train_begin()

def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are:'
f' `{"`, `".join(list(logs.keys()))}`')

if monitor_val is None:
if self.strict:
raise RuntimeError(error_msg)
elif self.verbose > 0:
warnings.warn(error_msg, RuntimeWarning)

return False

return True

def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
stop_training = False
if current is None:
warnings.warn(
f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
RuntimeWarning)
stop_training = True
if not self.check_metrics(logs):
return stop_training

current = logs.get(self.monitor)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
Expand Down
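A quick sketch of the new `check_metrics` contract in isolation, assuming the import path shown in the docstring above (`pytest` is used only to assert the failure mode):

    import pytest
    from pytorch_lightning.callbacks import EarlyStopping

    # strict=True (the default): a missing monitor key is a hard error.
    strict_cb = EarlyStopping(monitor='val_loss', strict=True)
    with pytest.raises(RuntimeError):
        strict_cb.check_metrics({'train_loss': 0.5})

    # strict=False: the same situation warns (if verbose) and reports the
    # metric as unusable, so on_epoch_end returns without stopping training.
    lenient_cb = EarlyStopping(monitor='val_loss', strict=False, verbose=1)
    assert lenient_cb.check_metrics({'train_loss': 0.5}) is False
    assert lenient_cb.check_metrics({'val_loss': 0.42}) is True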
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/callback_config.py
@@ -55,10 +55,20 @@ def configure_early_stopping(self, early_stop_callback, logger):
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
+                strict=True,
                 verbose=True,
                 mode='min'
             )
             self.enable_early_stop = True
+        elif early_stop_callback is None:
+            self.early_stop_callback = EarlyStopping(
+                monitor='val_loss',
+                patience=3,
+                strict=False,
+                verbose=False,
+                mode='min'
+            )
+            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
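The three-way dispatch above, condensed into a standalone function for illustration. This is a sketch, not the trainer's actual method, and the final branch for a user-supplied callback instance is assumed from surrounding code this hunk does not show:

    from pytorch_lightning.callbacks import EarlyStopping

    def resolve_early_stopping(early_stop_callback):
        """Sketch of how the Trainer maps the early_stop_callback argument."""
        if early_stop_callback is True:
            # Opt-in: strict, so a missing 'val_loss' raises at training start.
            return EarlyStopping(monitor='val_loss', patience=3,
                                 strict=True, verbose=True, mode='min'), True
        elif early_stop_callback is None:
            # New default: same callback, but lenient and quiet.
            return EarlyStopping(monitor='val_loss', patience=3,
                                 strict=False, verbose=False, mode='min'), True
        elif not early_stop_callback:
            # False (or another falsy value): early stopping disabled.
            return None, False
        # Assumed fallthrough: a user-supplied callback is used as-is.
        return early_stop_callback, True

    callback, enabled = resolve_early_stopping(None)
    assert enabled and callback.strict is False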
20 changes: 16 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@ def __init__(
             self,
             logger=True,
             checkpoint_callback=True,
-            early_stop_callback=True,
+            early_stop_callback=None,
             default_save_path=None,
             gradient_clip_val=0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -121,15 +121,22 @@ def __init__(
                 )
                 trainer = Trainer(checkpoint_callback=checkpoint_callback)

-        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
+        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
+            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+            Will raise an error if ``'val_loss'`` is not found.
+            If set to ``False``, then early stopping will be disabled.
+            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
+            If ``'val_loss'`` is not found will work as if early stopping is disabled.
+            Default: ``None``.
             Example::

                 from pytorch_lightning.callbacks import EarlyStopping

                 # default used by the Trainer
                 early_stop_callback = EarlyStopping(
                     monitor='val_loss',
                     patience=3,
-                    verbose=True,
+                    strict=False,
+                    verbose=False,
                     mode='min'
                 )
@@ -809,12 +816,17 @@ def run_pretrain_routine(self, model):
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)

-            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+            eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                         self.num_sanity_val_steps, False)
+            _, _, _, callback_metrics, _ = self.process_output(eval_results)

             # close progress bars
             self.main_progress_bar.close()
             self.val_progress_bar.close()

+            if self.enable_early_stop:
+                self.early_stop_callback.check_metrics(callback_metrics)
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
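The practical effect of wiring `check_metrics` into the sanity check: a model that never logs the monitored metric now fails before the first training epoch rather than after it. A hypothetical reproduction, assuming a metrics dict of the shape the sanity check produces:

    from pytorch_lightning.callbacks import EarlyStopping

    # Suppose the validation sanity check produced these callback metrics,
    # i.e. the model logged a differently named loss.
    callback_metrics = {'avg_val_loss': 0.31}

    # With early_stop_callback=True the Trainer holds a strict callback, so
    # run_pretrain_routine raises here, before any training step runs.
    EarlyStopping(monitor='val_loss', strict=True).check_metrics(callback_metrics)
    # -> RuntimeError: Early stopping conditioned on metric `val_loss` which
    #    is not available. Available metrics are: `avg_val_loss`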
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -346,7 +346,8 @@ def train(self):

             # early stopping
             met_min_epochs = epoch >= self.min_epochs - 1
-            if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
+            if (self.enable_early_stop and not self.disable_validation and
+                    (met_min_epochs or self.fast_dev_run)):
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                     logs=self.callback_metrics)
                 # stop training
@@ -401,6 +402,9 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)

+                if self.enable_early_stop:
+                    self.early_stop_callback.check_metrics(self.callback_metrics)
+
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:
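The gating condition above, reduced to a pure function (flag names follow the diff; a sketch for illustration only):

    def should_check_early_stop(enable_early_stop, disable_validation,
                                epoch, min_epochs, fast_dev_run):
        """Mirror of the guard around early_stop_callback.on_epoch_end."""
        met_min_epochs = epoch >= min_epochs - 1
        # When validation is disabled the monitored metric (e.g. 'val_loss')
        # can never appear, so early stopping is skipped entirely.
        return (enable_early_stop and not disable_validation and
                (met_min_epochs or fast_dev_run))

    assert should_check_early_stop(True, False, epoch=4, min_epochs=3, fast_dev_run=False)
    assert not should_check_early_stop(True, True, epoch=4, min_epochs=3, fast_dev_run=False)
    assert not should_check_early_stop(True, False, epoch=0, min_epochs=3, fast_dev_run=False)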
4 changes: 3 additions & 1 deletion tests/test_cpu_models.py
@@ -140,7 +140,8 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
         val_percent_check=0.2,
         test_percent_check=0.2,
         checkpoint_callback=checkpoint,
-        logger=logger
+        logger=logger,
+        early_stop_callback=False
     )

     # fit model
@@ -318,6 +319,7 @@ def train_dataloader(self):
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
+        early_stop_callback=False
     )

     hparams = tutils.get_hparams()
2 changes: 1 addition & 1 deletion tests/test_trainer.py
@@ -392,7 +392,7 @@ class CurrentTestModel(
         default_save_path=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
-        train_percent_check=0.2,
+        train_percent_check=0.2
     )

     # fit model
