Update PT to PL conversion doc (Lightning-AI#11397)
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
3 people authored Feb 21, 2022
1 parent 9b0942d commit f811284
Showing 1 changed file with 149 additions and 32 deletions: docs/source/starter/converting.rst

.. _converting:

######################################
How to organize PyTorch into Lightning
######################################

To enable your code to work with Lightning, here's how to organize PyTorch into Lightning:

--------

*******************************
1. Move your Computational Code
*******************************

Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.lightning.LightningModule`.

.. testcode::

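    # A minimal sketch: any model architecture and forward pass work here.
    # This assumes a small MNIST-style MLP; swap in your own layers.
    import torch
    from pytorch_lightning import LightningModule
    from torch import nn
    from torch.nn import functional as F


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = F.relu(self.layer_1(x))
            return self.layer_2(x)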

--------

********************************************
2. Move the Optimizer(s) and LR Scheduler(s)
********************************************

Move your optimizers to the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]
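
``configure_optimizers`` can also return a dictionary when the scheduler needs extra configuration, for example a ``ReduceLROnPlateau`` scheduler that steps based on a monitored metric. A minimal sketch, assuming a logged ``"val_loss"`` metric:

.. code-block:: python

    class LitModel(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # ReduceLROnPlateau steps based on a monitored metric, so the
            # scheduler is returned together with the metric name to monitor
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
            }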

--------

*******************************
3. Configure the Training Logic
*******************************

Lightning automates the training loop for you and manages all of the associated components, such as epoch and batch tracking, optimizers and schedulers,
and metric reduction. As a user, you just need to define how your model behaves with a batch of training data by overriding the
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method, which takes the current ``batch`` and the ``batch_idx``
as arguments. Optionally, it can take ``optimizer_idx`` if your LightningModule defines multiple optimizers within its
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.

.. testcode::

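    # A sketch of the training logic for one batch: compute and return the loss,
    # mirroring the validation and test steps below. Lightning runs the rest of
    # the loop (backward pass, optimizer step, iteration) for you.
    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss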

--------

*********************************
4. Configure the Validation Logic
*********************************

Lightning also automates the validation loop for you and manages all of the associated components, such as epoch and batch tracking and metrics reduction.
As a user, you just need to define how your model behaves with a batch of validation data by overriding the
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method, which takes the current ``batch`` and the ``batch_idx`` as arguments.
Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.

The validation loop is optional: if you don't override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` hook, the validation loop is skipped.

.. testcode::

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)

Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method.

.. code-block:: python

    model = LitModel()
    trainer.validate(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

.. tip:: ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.

--------

**************************
5. Configure Testing Logic
**************************

Lightning automates the testing loop for you and manages all of the associated components, such as epoch and batch tracking and metrics reduction.
As a user, you just need to define how your model behaves with a batch of testing data by overriding the
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method, which takes the current ``batch`` and the ``batch_idx`` as arguments.
Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.

.. testcode::

    class LitModel(LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)

The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`; therefore, you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.

.. code-block:: python

    model = LitModel()
    trainer.test(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

.. tip:: ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.

--------

*****************************
6. Configure Prediction Logic
*****************************

Lightning automates the prediction loop for you and manages all of the associated components, such as epoch and batch tracking.
As a user, you just need to define how your model behaves with a batch of data by overriding the
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method, which takes the current ``batch`` and the ``batch_idx`` as arguments.
Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.
If you don't override the ``predict_step`` hook, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` on the batch by default.

.. testcode::

    class LitModel(LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self(x)
            return pred

The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

.. code-block:: python

    model = LitModel()
    trainer.predict(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.

.. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.
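
To run inference over a full dataloader, pass it to ``trainer.predict``, which returns the collected outputs. A minimal sketch, where ``predict_loader`` is assumed to be your own :class:`~torch.utils.data.DataLoader`:

.. code-block:: python

    model = LitModel()
    # returns a list with one prediction per batch of predict_loader
    predictions = trainer.predict(model, dataloaders=predict_loader)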

--------

******************************************
7. Remove any .cuda() or .to(device) Calls
******************************************

Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!

If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader`
and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.
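
For example, explicit device-placement code from plain PyTorch can simply be deleted. A before/after sketch of the idea, where ``Net``, ``device``, and ``x`` stand in for your own code:

.. code-block:: python

    # Before: manual device management in plain PyTorch
    model = Net().to(device)
    x = x.to(device)

    # After: Lightning places the model and each batch for you
    model = LitModel()
    trainer = Trainer(gpus=1)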

.. testcode::

    class LitModel(LightningModule):
        def __init__(self, num_features):
            super().__init__()
            self.register_buffer("running_mean", torch.zeros(num_features))

If you are initializing a :class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically,
you must call :meth:`~torch.nn.Module.register_buffer` to register it as a buffer, as shown above. If you still need to access the current device,
you can use ``self.device`` anywhere in your ``LightningModule`` except the ``__init__`` method:

.. testcode::

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...

--------

********************
8. Use your own data
********************

To use your DataLoaders, you can override the respective dataloader hooks in the :class:`~pytorch_lightning.core.lightning.LightningModule`:

.. testcode::

    class LitModel(LightningModule):
        def train_dataloader(self):
            return DataLoader(...)

        def val_dataloader(self):
            return DataLoader(...)

        def test_dataloader(self):
            return DataLoader(...)

        def predict_dataloader(self):
            return DataLoader(...)

Alternatively, you can pass your dataloaders in one of the following ways (a sketch follows the list):

* Pass the dataloaders explicitly into the ``trainer.fit/.validate/.test/.predict`` calls.
* Use a :ref:`LightningDataModule <datamodules>`.
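
A minimal sketch of the first option, where ``train_loader``, ``val_loader``, and ``test_loader`` are assumed to be your own :class:`~torch.utils.data.DataLoader` instances:

.. code-block:: python

    model = LitModel()
    trainer = Trainer()

    # dataloaders passed directly to each Trainer entry point
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, dataloaders=test_loader)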

Check out the :ref:`data` docs to understand data management within Lightning.
