Avoid torch amp cuda warning with bf16 on cpu (#11161)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people committed Dec 21, 2021
1 parent 1521732 commit 60c276e
Showing 2 changed files with 13 additions and 5 deletions.
CHANGELOG.md (1 addition, 0 deletions)

@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
 - Fixed an `AttributeError` occurring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111))
 - Fixed a bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
+- Fixed an incorrect warning being produced by the model summary when using `bf16` precision on CPU ([#11161](https://github.com/PyTorchLightning/pytorch-lightning/pull/11161))
 
 ### Changed

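For context, the new changelog entry refers to running the model summary's example forward pass under `bf16` precision on a CPU-only machine. Below is a minimal sketch of that setup, assuming PyTorch Lightning ~1.5 with torch >= 1.10; `TinyModel` is a hypothetical module written for illustration, not code from this repository:

```python
import torch
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    # Setting example_input_array makes the model summary run a real
    # forward pass, which is where the spurious warning appeared.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.example_input_array = torch.randn(4, 32)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self.forward(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# bf16 on CPU: before this commit, summarizing the model wrapped forward in
# torch.cuda.amp.autocast() even without CUDA, and torch warned that CUDA
# autocast would be disabled. Afterwards, the trainer's precision plugin
# supplies a context matching the configured device and precision.
trainer = pl.Trainer(precision="bf16", max_epochs=1)
```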
pytorch_lightning/utilities/model_summary.py (12 additions, 5 deletions)

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import logging
+import sys
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -23,7 +25,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
+from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -282,12 +284,17 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model._apply_batch_transfer_handler(input_)
 
-        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
         mode = model.training
         model.eval()
-        with torch.no_grad():
+
+        if trainer is not None:
+            forward_context = trainer.precision_plugin.forward_context()
+        elif sys.version_info >= (3, 7):
+            forward_context = contextlib.nullcontext()
+        else:
+            forward_context = contextlib.suppress()
+
+        with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
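The gist of the change: instead of unconditionally wrapping `forward` in `torch.cuda.amp.autocast()`, the summary asks the trainer's precision plugin for the right context, falling back to a no-op context when there is no trainer. A standalone sketch of that logic follows; the helper name is hypothetical, while Lightning's real implementation is the `trainer.precision_plugin.forward_context()` call shown in the diff above:

```python
import contextlib
import sys

import torch


def _summary_forward_context(use_cuda_autocast: bool):
    # Hypothetical stand-in for trainer.precision_plugin.forward_context().
    if use_cuda_autocast:
        # Only meaningful when CUDA ops actually run; entering this on a
        # CPU-only machine is what produced the spurious warning.
        return torch.cuda.amp.autocast()
    if sys.version_info >= (3, 7):
        return contextlib.nullcontext()
    # contextlib.nullcontext() is new in Python 3.7; suppress() with no
    # arguments behaves as an equivalent no-op context manager on 3.6.
    return contextlib.suppress()


model = torch.nn.Linear(8, 2).eval()
example = torch.randn(1, 8)

with torch.no_grad(), _summary_forward_context(torch.cuda.is_available()):
    model(example)  # no CUDA-autocast warning on CPU-only machines
```

Note the Python 3.6 branch: `contextlib.suppress()` with no exception types simply enters and exits without doing anything, which is why it can stand in for `nullcontext` on interpreters that predate it.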
