Revert "Skip tuner algorithms on fast dev (#3903)"
This reverts commit 189ed25
SeanNaren committed Nov 11, 2020
1 parent 4c61f70 · commit 7625e49
Showing 4 changed files with 5 additions and 43 deletions.
CHANGELOG.md (4 changes: 2 additions & 2 deletions)

@@ -30,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))


-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


 - Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
@@ -41,7 +41,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
+

 ### Deprecated

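The entry removed in the second hunk documented the behaviour this commit undoes: with #3903, the tuner warned and returned early whenever `fast_dev_run=True`. After the revert the tuner runs even in fast-dev mode, so a caller who still wants the skip has to guard the call themselves. A minimal sketch of such a guard, assuming the `Trainer.tune()` entry point of this release line (the helper name is ours, not Lightning API):

```python
from pytorch_lightning.utilities import rank_zero_warn


def maybe_tune(trainer, model):
    """Run the tuner unless fast_dev_run is set, mirroring the reverted guard."""
    if trainer.fast_dev_run:
        # same warning style as the guards removed in the diffs below
        rank_zero_warn('Skipping tuner since `fast_dev_run=True`', UserWarning)
        return
    trainer.tune(model)  # assumption: Trainer.tune() drives the tuner algorithms here
```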
pytorch_lightning/tuner/batch_size_scaling.py (4 changes: 0 additions & 4 deletions)

@@ -68,10 +68,6 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
-    if trainer.fast_dev_run:
-        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
-        return
-
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
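The check retained above still requires `batch_arg_name` (by default `batch_size`) to be reachable on the model or its `hparams`; otherwise a `MisconfigurationException` is raised. A hedged sketch of a module layout that satisfies the check, with an illustrative class name and toy dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TunableModel(pl.LightningModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size  # what lightning_hasattr(model, "batch_size") finds
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # the scaler rewrites self.batch_size, so the dataloader must read it
        dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=self.batch_size)
```

With a module like this, `Trainer(auto_scale_batch_size='power')` followed by `trainer.tune(model)` should reach `scale_batch_size`; treat that wiring as an assumption about the 1.0.x API rather than a guarantee.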
pytorch_lightning/tuner/lr_finder.py (19 changes: 3 additions & 16 deletions)

@@ -29,8 +29,6 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import get_filesystem

 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -43,10 +41,6 @@
 def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
-
-    if lr_finder is None:
-        return
-
     lr = lr_finder.suggestion()

     # TODO: log lr.results to self.logger
@@ -136,11 +130,7 @@ def lr_find(
         trainer.fit(model)

     """
-    if trainer.fast_dev_run:
-        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
-        return
-
-    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')
+    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp.ckpt')

     __lr_finder_dump_params(trainer, model)

@@ -191,11 +181,8 @@ def lr_find(
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

     # Reset model state
-    if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
-        fs = get_filesystem(str(save_path))
-        if fs.exists(save_path):
-            fs.rm(save_path)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
+    os.remove(save_path)

     # Finish by resetting variables so trainer is ready to fit model
     __lr_finder_restore_params(trainer, model)
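The last hunk also reverts a design choice: the removed code routed the temporary checkpoint cleanup through `get_filesystem` (Lightning's fsspec wrapper) and gated it on `trainer.is_global_zero`, while the restored code assumes a local `default_root_dir`, calls `os.remove`, and is no longer gated on the global-zero rank. A minimal sketch of the fsspec-style cleanup, written against fsspec directly with an illustrative filename:

```python
import fsspec

save_path = "lr_find_temp.ckpt"  # illustrative; mirrors the temporary checkpoint above

# create a placeholder file so the cleanup below has something to delete
with open(save_path, "w") as f:
    f.write("placeholder checkpoint")

# "file" targets the local filesystem; the same exists/rm calls work for remote
# backends (e.g. "s3") when the matching fsspec implementation is installed
fs = fsspec.filesystem("file")
if fs.exists(save_path):
    fs.rm(save_path)
```

The restored code path, `os.remove(save_path)`, only handles local paths, which is the trade-off this revert reintroduces.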
tests/trainer/flags/test_fast_dev_run.py (21 changes: 0 additions & 21 deletions)

This file was deleted.
