diff --git a/docs/en/common_usage/speed_up_training.md b/docs/en/common_usage/speed_up_training.md index 9350cbf4fc..96b9e00e85 100644 --- a/docs/en/common_usage/speed_up_training.md +++ b/docs/en/common_usage/speed_up_training.md @@ -60,7 +60,7 @@ MMEngine supports training models with CPU, single GPU, multiple GPUs in single ## Mixed Precision Training -Nvidia introduced the Tensor Core unit into the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. With automatic mixed precision training enabled, some operators operate at FP16 and the rest operate at FP32, which reduces training time and storage requirements without changing the model or degrading its training precision, thus supporting training with larger batch sizes, larger models, and larger input sizes. +Nvidia introduced the Tensor Core unit into the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. They further support BF16 in Ampere architectures. With automatic mixed precision training enabled, some operators operate at FP16/BF16 and the rest operate at FP32, which reduces training time and storage requirements without changing the model or degrading its training precision, thus supporting training with larger batch sizes, larger models, and larger input sizes. [PyTorch officially supports amp from 1.6](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). If you are interested in the implementation of automatic mixing precision, you can refer to [Mixed Precision Training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html). @@ -71,8 +71,16 @@ runner = Runner( model=ResNet18(), work_dir='./work_dir', train_dataloader=train_dataloader_cfg, - optim_wrapper=dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9)), + optim_wrapper=dict( + type='AmpOptimWrapper', + # If you want to use bfloat16, uncomment the following line + # dtype='bfloat16', # valid values: ('float16', 'bfloat16', None) + optimizer=dict(type='SGD', lr=0.001, momentum=0.9)), train_cfg=dict(by_epoch=True, max_epochs=3), ) runner.train() ``` + +```{warning} +Up till PyTorch 1.13, `torch.bfloat16` performance on `Convolution` is bad unless manually set environment variable `TORCH_CUDNN_V8_API_ENABLED=1`. More context at [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767) +``` diff --git a/docs/zh_cn/common_usage/speed_up_training.md b/docs/zh_cn/common_usage/speed_up_training.md index 0aef22437b..7f7b86b778 100644 --- a/docs/zh_cn/common_usage/speed_up_training.md +++ b/docs/zh_cn/common_usage/speed_up_training.md @@ -61,7 +61,7 @@ MMEngine 支持 CPU、单卡、单机多卡以及多机多卡的训练。当环 ## 混合精度训练 -Nvidia 在 Volta 和 Turing 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。开启自动混合精度训练后,部分算子的操作精度是 FP16,其余算子的操作精度是 FP32。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。 +Nvidia 在 Volta 和 Turing 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。在 Ampere 架构中,他们进一步支持了 BF16 计算。开启自动混合精度训练后,部分算子的操作精度是 FP16/BF16,其余算子的操作精度是 FP32。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。 [PyTorch 从 1.6 开始官方支持 amp](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/)。如果你对自动混合精度的实现感兴趣,可以阅读 [torch.cuda.amp: 自动混合精度详解](https://zhuanlan.zhihu.com/p/348554267)。 @@ -72,8 +72,16 @@ runner = Runner( model=ResNet18(), work_dir='./work_dir', train_dataloader=train_dataloader_cfg, - optim_wrapper=dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9)), + optim_wrapper=dict( + type='AmpOptimWrapper', + # 如果你想要使用 BF16,请取消下面一行的代码注释 + # dtype='bfloat16', # 可用值: ('float16', 'bfloat16', None) + optimizer=dict(type='SGD', lr=0.001, momentum=0.9)), train_cfg=dict(by_epoch=True, max_epochs=3), ) runner.train() ``` + +```{warning} +截止到 PyTorch 1.13 版本,在 `Convolution` 中直接使用 `torch.bfloat16` 性能低下,必须手动设置环境变量 `TORCH_CUDNN_V8_API_ENABLED=1` 以启用 CuDNN 版本的 BF16 Convolution。相关讨论见 [PyTorch Issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767) +``` diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 66f1255643..fcc3b4c50f 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager +from typing import Union import torch import torch.nn as nn @@ -38,15 +39,30 @@ class AmpOptimWrapper(OptimWrapper): - float: Initialize GradScaler with ``init_scale``. - dict: Initialize GradScaler with more detail configuration. + dtype (str or torch.dtype, optional): The data type to autocast in amp. + If a ``str`` is given, it will be converted to ``torch.dtype``. + Valid ``str`` format are `'float16'`, `'bfloat16'`, `'float32'` and + `'float64'`. If set to ``None``, the default data type will be used. + Defaults to None. + `New in version 0.6.1.` **kwargs: Keyword arguments passed to OptimWrapper. + Warnings: + ``dtype`` argument is only available with PyTorch version >= 1.10.0. If + you use PyTorch of an older version, it will be ignored. + Note: If you use ``IterBasedRunner`` and enable gradient accumulation, the original `max_iters` should be multiplied by ``accumulative_counts``. """ - def __init__(self, loss_scale='dynamic', **kwargs): + valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64') + + def __init__(self, + loss_scale: str = 'dynamic', + dtype: Union[str, torch.dtype] = None, + **kwargs): assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( '`torch.cuda.amp` is only available when pytorch version >= 1.6') assert is_cuda_available() or is_npu_available(), ( @@ -68,6 +84,16 @@ def __init__(self, loss_scale='dynamic', **kwargs): raise TypeError('loss_scale must be of type float, dict, or ' f'"dynamic", but got {loss_scale}') + # convert string value to torch.dtype + if isinstance(dtype, str): + assert dtype in self.valid_dtypes, ( + f'dtype should be any of {self.valid_dtypes}, got {dtype}') + dtype = getattr(torch, dtype) + + assert dtype is None or isinstance(dtype, torch.dtype), ( + f'dtype should be None or instance of torch.dtype, got {dtype}') + self.cast_dtype = dtype + def backward(self, loss: torch.Tensor, **kwargs): """Perform gradient back propagation with :attr:`loss_scaler`. @@ -133,5 +159,5 @@ def optim_context(self, model: nn.Module): model (nn.Module): The training model. """ from mmengine.runner.amp import autocast - with super().optim_context(model), autocast(): + with super().optim_context(model), autocast(dtype=self.cast_dtype): yield diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index d00033f210..5ebebcb4c2 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -4,10 +4,10 @@ from unittest import TestCase from unittest.mock import MagicMock -import pytest import torch import torch.distributed as torch_dist import torch.nn as nn +from parameterized import parameterized from torch.cuda.amp import GradScaler from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD, Adam, Optimizer @@ -17,8 +17,6 @@ from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase -from mmengine.utils import digit_version -from mmengine.utils.dl_utils import TORCH_VERSION is_apex_available = False try: @@ -27,6 +25,17 @@ except ImportError: pass +amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None] +torch_dtypes = [ + torch.float16 if dtype is None else getattr(torch, dtype) + for dtype in amp_valid_dtypes +] + + +def bf16_supported() -> bool: + return (hasattr(torch.cuda, 'is_bf16_supported') + and torch.cuda.is_bf16_supported()) + class ToyModel(nn.Module): @@ -196,7 +205,7 @@ def test_step(self): # TODO: This unit test could cause CI to fail with some probability, which # is caused by MultiProcessTestCase. This problem should be solved # in the future). - @pytest.mark.skipif(True, reason='Solved in the future') + @unittest.skipIf(True, reason='Solved in the future') def test_clip_grads(self): # Test `clip_grad` with `clip_norm_` optim_wrapper = OptimWrapper( @@ -392,10 +401,8 @@ def setUp(self) -> None: self.optimizer = SGD(self.model.parameters(), lr=0.1) @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_init(self): # Test with default arguments. amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) @@ -407,6 +414,16 @@ def test_init(self): self.assertIsNone(amp_optim_wrapper._scale_update_param) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + # Test with dtype float16 + amp_optim_wrapper = AmpOptimWrapper( + dtype='float16', optimizer=self.optimizer) + self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16) + + # Test with dtype bfloat16 + amp_optim_wrapper = AmpOptimWrapper( + dtype='bfloat16', optimizer=self.optimizer) + self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16) + # Test with dict loss_scale. amp_optim_wrapper = AmpOptimWrapper( dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) @@ -416,14 +433,15 @@ def test_init(self): 'loss_scale must be of type float'): AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown') + @parameterized.expand(list(zip(amp_valid_dtypes))) @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') - def test_step(self): + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + def test_step(self, dtype): + if dtype == 'bfloat16' and not bf16_supported(): + raise unittest.SkipTest('bfloat16 not supported by device') optimizer = MagicMock(spec=Optimizer) - amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer) + amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype) amp_optim_wrapper.loss_scaler = MagicMock() amp_optim_wrapper.step() amp_optim_wrapper.loss_scaler.step.assert_called_with( @@ -431,13 +449,15 @@ def test_step(self): amp_optim_wrapper.loss_scaler.update.assert_called_with( amp_optim_wrapper._scale_update_param) + @parameterized.expand(list(zip(amp_valid_dtypes))) @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') - def test_backward(self): - amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + def test_backward(self, dtype): + if dtype == 'bfloat16' and not bf16_supported(): + raise unittest.SkipTest('bfloat16 not supported by device') + amp_optim_wrapper = AmpOptimWrapper( + optimizer=self.optimizer, dtype=dtype) loss_scaler = MagicMock() scale_return = MagicMock() scale_fn = MagicMock(return_value=scale_return) @@ -449,10 +469,8 @@ def test_backward(self): scale_return.backward.assert_called_with() @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_state_dict(self): self.model = self.model.cuda() amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) @@ -468,10 +486,8 @@ def test_state_dict(self): amp_optim_wrapper.loss_scaler.state_dict()) @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_load_state_dict(self): amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) self.model = self.model.cuda() @@ -491,14 +507,16 @@ def test_load_state_dict(self): self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(), amp_optim_wrapper_.loss_scaler.state_dict()) + @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes))) @unittest.skipIf( - not torch.cuda.is_available() - and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')), - reason='`torch.cuda.amp` is only available when pytorch-gpu version ' - '>= 1.6') - def test_optim_context(self): - amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + not torch.cuda.is_available(), + reason='`torch.cuda.amp` is only available when pytorch-gpu installed') + def test_optim_context(self, dtype, target_dtype): + if dtype == 'bfloat16' and not bf16_supported(): + raise unittest.SkipTest('bfloat16 not supported by device') + amp_optim_wrapper = AmpOptimWrapper( + optimizer=self.optimizer, dtype=dtype) with amp_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) - self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.dtype, target_dtype)