Learning Rate finder (#1347)

* initial structure

* rebase

* incorporate suggestions

* update CHANGELOG.md

* initial docs

* fixes based on reviews

* added trainer arg

* update docs

* added saving/restore of model state

* initial tests

* fix styling

* added more tests

* fix docs, backward compatibility and progressbar

* fix styling

* docs update

* updates based on review

* changed saving to standard functions

* consistent naming

* fix formatting

* improve docs, added support for nested fields, improve codecov

* update CHANGELOG.md

* Update lr_finder.rst

* Update pytorch_lightning/trainer/trainer.py

* Update trainer.py

* Update CHANGELOG.md

* Update path

* restoring

* test

* attribs

* docs

* doc typo

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
5 people authored Apr 10, 2020
1 parent d05ac81 commit 3f09b32
Showing 9 changed files with 773 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))

-

Binary file added docs/source/_images/trainer/lr_finder.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -66,6 +66,7 @@ PyTorch Lightning Documentation
fast_training
hooks
hyperparameters
lr_finder
multi_gpu
multiple_loaders
weights_loading
108 changes: 108 additions & 0 deletions docs/source/lr_finder.rst
@@ -0,0 +1,108 @@
Learning Rate Finder
--------------------

For training deep neural networks, selecting a good learning rate is essential
for both better performance and faster convergence. Even optimizers that adapt
the learning rate on the fly, such as `Adam`, can benefit from a well-chosen
initial value.

To reduce the amount of guesswork in choosing a good initial learning
rate, a `learning rate finder` can be used. As described in this `paper <https://arxiv.org/abs/1506.01186>`_,
a learning rate finder does a short run in which the learning rate is increased
after each processed batch and the corresponding loss is logged. The result is
a `lr` vs. `loss` plot that can be used as guidance for choosing an optimal
initial lr.
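
The underlying idea is simple enough to sketch in plain PyTorch. The following
is only an illustration of the range test, not Lightning's actual
implementation; ``model``, ``optimizer`` and ``train_loader`` are assumed to be
defined elsewhere, and a classification cross-entropy loss is assumed:

.. code-block:: python

    import torch

    # sweep the learning rate exponentially from lr_min to lr_max
    lr_min, lr_max, num_steps = 1e-8, 1.0, 100
    gamma = (lr_max / lr_min) ** (1 / num_steps)  # multiplicative lr step

    lrs, losses = [], []
    for step, (x, y) in zip(range(num_steps), train_loader):
        lr = lr_min * gamma ** step
        for group in optimizer.param_groups:
            group["lr"] = lr

        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()

        # log the pair that ends up on the lr vs. loss plot
        lrs.append(lr)
        losses.append(loss.item())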

.. warning:: For the moment, this feature only works with models having a single optimizer.

Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the most basic use case, this feature can be enabled during trainer construction
with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the lr finder
will automatically be run before any training is done. The ``lr`` that is found
and used will be written to the console and logged together with all other
hyperparameters of the model.

.. code-block:: python

    # enable the learning rate finder (disabled by default)
    trainer = Trainer(auto_lr_find=True)

When an ``lr`` or ``learning_rate`` key exists in hparams, this flag sets that
field to the learning rate that was found. In both cases, if the respective
fields are not found, an error will be thrown.

.. code-block:: python

    class LitModel(LightningModule):

        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams

        def configure_optimizers(self):
            # either hparams.lr or hparams.learning_rate works here
            return Adam(self.parameters(), lr=self.hparams.lr)

    # finds the learning rate automatically and sets
    # hparams.lr or hparams.learning_rate to that learning rate
    trainer = Trainer(auto_lr_find=True)

To write the found learning rate to an arbitrary attribute of hparams, pass its name as a string.

.. code-block:: python

    # to set your own hparams.my_value
    trainer = Trainer(auto_lr_find='my_value')

Under the hood, when you call ``.fit()``, this is what happens:

1. Run learning rate finder.
2. Run actual fit.

.. code-block:: python

    # when you call .fit() this happens
    # 1. find the learning rate
    # 2. actually run fit
    trainer.fit(model)

If you want to inspect the results of the learning rate finder before doing any
actual training, or just play around with the parameters of the algorithm, you
can invoke the ``lr_find`` method of the trainer. A typical example would look
like:

.. code-block:: python

    model = MyModelClass(hparams)
    trainer = pl.Trainer()

    # run learning rate finder
    lr_finder = trainer.lr_find(model)

    # results can be found in
    lr_finder.results

    # plot with
    fig = lr_finder.plot(suggest=True)
    fig.show()

    # pick point based on plot, or get suggestion
    new_lr = lr_finder.suggestion()

    # update hparams of the model
    model.hparams.lr = new_lr

    # fit model
    trainer.fit(model)

The figure produced by ``lr_finder.plot()`` should look something like the figure
below. It is recommended not to pick the learning rate that achieves the lowest
loss, but instead something in the middle of the sharpest downward slope (red point).
This is the point returned by ``lr_finder.suggestion()``.

.. figure:: /_images/trainer/lr_finder.png

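For illustration, here is a minimal sketch of how such a suggestion could be
computed from the recorded values. It assumes ``lr_finder.results`` is a mapping
with ``'lr'`` and ``'loss'`` lists and simply picks the learning rate where the
smoothed loss falls fastest; Lightning's own heuristic may differ in detail:

.. code-block:: python

    import numpy as np

    # assumed structure: lr_finder.results == {'lr': [...], 'loss': [...]}
    lrs = np.array(lr_finder.results["lr"])
    losses = np.array(lr_finder.results["loss"])

    # smooth the loss curve to suppress batch-to-batch noise
    window = 5
    smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")

    # steepest downward slope = most negative gradient of the smoothed loss
    idx = int(np.argmin(np.gradient(smoothed)))
    new_lr = lrs[idx + window // 2]  # re-align index with the original arrays
    print(f"suggested lr: {new_lr:.2e}")
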
The parameters of the algorithm can be seen below.

.. autoclass:: pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin
:members: lr_find
:noindex:
:exclude-members: _run_lr_finder_internally, save_checkpoint, restore
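
For example, a run with a customized search range might look like the following
sketch; the keyword names used here (``min_lr``, ``max_lr``, ``num_training``)
are illustrative, so check the signature documented above for the exact names:

.. code-block:: python

    # sweep 100 steps between 1e-6 and 1e-1 instead of the defaults
    lr_finder = trainer.lr_find(model, min_lr=1e-6, max_lr=1e-1, num_training=100)
    new_lr = lr_finder.suggestion()
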
21 changes: 21 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -135,6 +135,27 @@ def forward(self, x):
    # default used by the Trainer
    trainer = Trainer(amp_level='O1')

auto_lr_find
^^^^^^^^^^^^

Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
before any training, to find the optimal initial learning rate.

.. code-block:: python

    # default used by the Trainer (no learning rate finder)
    trainer = Trainer(auto_lr_find=False)

Example::

    # run learning rate finder, results override hparams.learning_rate
    trainer = Trainer(auto_lr_find=True)

    # run learning rate finder, results override hparams.my_lr_arg
    trainer = Trainer(auto_lr_find='my_lr_arg')

.. note::
    See the `learning rate finder guide <lr_finder.rst>`_

benchmark
^^^^^^^^^