
Add datamodule parameter to lr_find() #3425

Merged · 13 commits · Oct 1, 2020
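The change in a nutshell: `lr_find()` (and its `Tuner.lr_find` wrapper) now accepts a `LightningDataModule`, so the learning rate range test can run without passing raw dataloaders. A minimal usage sketch, borrowing `EvalModelTemplate` and `TrialMNISTDataModule` from the Lightning test suite; the data path and `max_epochs` are arbitrary choices, not part of the PR:

```python
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule

# Any LightningModule / LightningDataModule pair works the same way.
dm = TrialMNISTDataModule("./data")
model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())

trainer = Trainer(max_epochs=2)

# New in this PR: the datamodule is passed straight to the LR finder
lr_finder = trainer.tuner.lr_find(model, datamodule=dm)

# Write the suggested learning rate back to the model before the real fit
model.learning_rate = lr_finder.suggestion()
trainer.fit(model, datamodule=dm)
```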
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563))

- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
28 changes: 19 additions & 9 deletions pytorch_lightning/tuner/lr_finder.py
@@ -11,21 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from typing import Optional, Sequence, List, Union
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.optim.lr_scheduler import _LRScheduler
import importlib
from pytorch_lightning import _logger as log
import numpy as np
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr


# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
@@ -71,6 +73,7 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
):
r"""
lr_find enables the user to do a range test of good initial learning rates,
@@ -81,7 +84,7 @@

train_dataloader: A PyTorch
DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
a predefined train_dataloader method, this will be skipped.

min_lr: minimum learning rate to investigate

@@ -98,6 +101,12 @@
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.

datamodule: An optional `LightningDataModule` which holds the training
and validation dataloader(s). Note that the `train_dataloader` and
`val_dataloaders` parameters cannot be used at the same time as
this parameter, or a `MisconfigurationException` will be raised.


Example::

# Setup model and trainer
@@ -167,7 +176,8 @@ def lr_find(
# Fit, lr & loss logged in callback
trainer.fit(model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders)
val_dataloaders=val_dataloaders,
datamodule=datamodule)

# Prompt if we stopped early
if trainer.global_step != num_training:
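The docstring above states that `datamodule` cannot be combined with `train_dataloader`/`val_dataloaders`. A sketch of that failure mode from the caller's side, reusing `trainer`, `model`, and `dm` from the sketch at the top; `conflicting_loader` is a hypothetical stand-in, and the exact exception message is not shown in this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.exceptions import MisconfigurationException

# A throwaway dataloader, only here to create the conflicting combination.
conflicting_loader = DataLoader(
    TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,))),
    batch_size=32,
)

# Supplying both an explicit dataloader and a datamodule is ambiguous,
# so the call is expected to be rejected with a MisconfigurationException.
try:
    trainer.tuner.lr_find(model, train_dataloader=conflicting_loader, datamodule=dm)
except MisconfigurationException as err:
    print(f"rejected as expected: {err}")
```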
5 changes: 4 additions & 1 deletion pytorch_lightning/tuner/tuning.py
@@ -15,6 +15,7 @@
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from typing import Optional, List, Union
from torch.utils.data import DataLoader

@@ -50,6 +51,7 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None
):
return lr_find(
self.trainer,
@@ -60,7 +62,8 @@
max_lr,
num_training,
mode,
early_stop_threshold
early_stop_threshold,
datamodule,
)

def internal_find_lr(self, trainer, model: LightningModule):
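`Tuner.lr_find` is a thin wrapper that now also forwards `datamodule` to the functional `lr_find`. For context, a sketch of how the returned object is typically consumed, assuming the `_LRFinder` result API of this release (`results`, `plot()`, `suggestion()`) and continuing from the sketch at the top:

```python
# `lr_finder` is the object returned by trainer.tuner.lr_find(model, datamodule=dm)

# Raw sweep data: the learning rates tried and the loss recorded at each step
print(lr_finder.results["lr"][:5])
print(lr_finder.results["loss"][:5])

# Optional visualisation (needs matplotlib); suggest=True marks the chosen point
fig = lr_finder.plot(suggest=True)
fig.savefig("lr_finder.png")

# suggestion() picks the lr where the loss curve falls most steeply
model.learning_rate = lr_finder.suggestion()
```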
25 changes: 25 additions & 0 deletions tests/trainer/test_lr_finder.py
@@ -5,6 +5,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule


def test_error_on_more_than_1_optimizer(tmpdir):
@@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir):
'Learning rate was not altered after running learning rate finder'


def test_datamodule_parameter(tmpdir):
""" Test that the datamodule parameter works """

# trial datamodule
dm = TrialMNISTDataModule(tmpdir)

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

before_lr = hparams.get('learning_rate')
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
)

lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
after_lr = lrfinder.suggestion()
model.learning_rate = after_lr

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'


def test_accumulation_and_early_stopping(tmpdir):
""" Test that early stopping of learning rate finder works, and that
accumulation also works for this feature """
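`TrialMNISTDataModule` is a helper from the Lightning test suite, so its definition does not appear in this diff. As a rough illustration of what the LR finder needs from a datamodule (an illustrative stand-in, not the actual helper), something as small as this would do:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.core.datamodule import LightningDataModule


class RandomDataModule(LightningDataModule):
    """Stand-in for TrialMNISTDataModule: random tensors instead of MNIST."""

    def __init__(self, num_samples: int = 256, batch_size: int = 32):
        super().__init__()
        self.num_samples = num_samples
        self.batch_size = batch_size

    def setup(self, stage=None):
        x = torch.randn(self.num_samples, 28 * 28)
        y = torch.randint(0, 10, (self.num_samples,))
        self.dataset = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)
```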