[Feature] enable bf16 in AmpOptimWrapper #960

Merged · 18 commits · Mar 1, 2023
add unittests for bf16 in AmpOptimWrapper
C1rN09 committed Feb 24, 2023
commit e1d78910996760e6621491f2dba9824eb365bfc5
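
For context, a minimal usage sketch of the feature these tests exercise: wrapping an optimizer with AmpOptimWrapper and dtype='bfloat16'. This is a sketch based only on the calls made in the tests below, assuming the mmengine.optim import path and a CUDA device with bf16 support; it is not an official example from this PR.

import torch
import torch.nn as nn
from mmengine.optim import AmpOptimWrapper  # assumed import path

# Toy model/optimizer pair wrapped with bf16 mixed precision.
model = nn.Conv2d(1, 1, 1).cuda()
optim_wrapper = AmpOptimWrapper(
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1), dtype='bfloat16')

# Forward inside optim_context so activations are autocast to bfloat16.
with optim_wrapper.optim_context(model):
    y = model(torch.randn(1, 1, 1, 1).cuda())  # y.dtype == torch.bfloat16

# update_params runs backward, step and zero_grad in one call.
optim_wrapper.update_params(y.sum())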
107 changes: 107 additions & 0 deletions tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -407,6 +407,16 @@ def test_init(self):
        self.assertIsNone(amp_optim_wrapper._scale_update_param)
        self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)

        # Test with dtype float16
        amp_optim_wrapper = AmpOptimWrapper(
            dtype='float16', optimizer=self.optimizer)
        self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16)

        # Test with dtype bfloat16
        amp_optim_wrapper = AmpOptimWrapper(
            dtype='bfloat16', optimizer=self.optimizer)
        self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16)

        # Test with dict loss_scale.
        amp_optim_wrapper = AmpOptimWrapper(
            dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
@@ -502,3 +512,100 @@ def test_optim_context(self):
            x = torch.randn(1, 1, 1, 1).cuda()
            y = nn.Conv2d(1, 1, 1).cuda()(x)
            self.assertEqual(y.dtype, torch.float16)

    @unittest.skipIf(
        not torch.cuda.is_available()
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
        or not torch.cuda.is_bf16_supported(),
        reason='`torch.cuda.amp` with bf16 is only available when pytorch-gpu'
        ' version >= 1.6 && bf16 supported by device')
    def test_step_bf16(self):
        optimizer = MagicMock(spec=Optimizer)
        amp_optim_wrapper = AmpOptimWrapper(
            optimizer=optimizer, dtype='bfloat16')
        amp_optim_wrapper.loss_scaler = MagicMock()
        amp_optim_wrapper.step()
        amp_optim_wrapper.loss_scaler.step.assert_called_with(
            amp_optim_wrapper.optimizer)
        amp_optim_wrapper.loss_scaler.update.assert_called_with(
            amp_optim_wrapper._scale_update_param)

    @unittest.skipIf(
        not torch.cuda.is_available()
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
        or not torch.cuda.is_bf16_supported(),
        reason='`torch.cuda.amp` with bf16 is only available when pytorch-gpu'
        ' version >= 1.6 && bf16 supported by device')
    def test_backward_bf16(self):
        amp_optim_wrapper = AmpOptimWrapper(
            optimizer=self.optimizer, dtype='bfloat16')
        loss_scaler = MagicMock()
        scale_return = MagicMock()
        scale_fn = MagicMock(return_value=scale_return)
        loss_scaler.scale = scale_fn
        amp_optim_wrapper.loss_scaler = loss_scaler

        amp_optim_wrapper.backward(1)
        loss_scaler.scale.assert_called_with(1)
        scale_return.backward.assert_called_with()

    @unittest.skipIf(
        not torch.cuda.is_available()
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
        or not torch.cuda.is_bf16_supported(),
        reason='`torch.cuda.amp` with bf16 is only available when pytorch-gpu'
        ' version >= 1.6 && bf16 supported by device')
    def test_state_dict_bf16(self):
        self.model = self.model.cuda()
        amp_optim_wrapper = AmpOptimWrapper(
            optimizer=self.optimizer, dtype='bfloat16')
        loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
        amp_optim_wrapper.update_params(loss)
        state_dict = amp_optim_wrapper.state_dict()
        scaler_state_dict = state_dict.pop('loss_scaler')
        optim_state_dict = state_dict

        self.assertDictEqual(optim_state_dict,
                             amp_optim_wrapper.optimizer.state_dict())
        self.assertDictEqual(scaler_state_dict,
                             amp_optim_wrapper.loss_scaler.state_dict())

    @unittest.skipIf(
        not torch.cuda.is_available()
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
        or not torch.cuda.is_bf16_supported(),
        reason='`torch.cuda.amp` with bf16 is only available when pytorch-gpu'
        ' version >= 1.6 && bf16 supported by device')
    def test_load_state_dict_bf16(self):
        amp_optim_wrapper = AmpOptimWrapper(
            optimizer=self.optimizer, dtype='bfloat16')
        self.model = self.model.cuda()
        # Test load from optimizer
        optimizer = SGD(self.model.parameters(), lr=0.1)
        amp_optim_wrapper.load_state_dict(optimizer.state_dict())

        self.assertDictEqual(optimizer.state_dict(),
                             amp_optim_wrapper.optimizer.state_dict())
        # Test load from optim_wrapper
        amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
        amp_optim_wrapper_ = AmpOptimWrapper(
            optimizer=SGD(self.model.parameters(), lr=0.1))
        amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict())
        self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(),
                             amp_optim_wrapper_.optimizer.state_dict())
        self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(),
                             amp_optim_wrapper_.loss_scaler.state_dict())

    @unittest.skipIf(
        not torch.cuda.is_available()
        or digit_version(TORCH_VERSION) < digit_version('1.6.0')
        or not torch.cuda.is_bf16_supported(),
        reason='`torch.cuda.amp` with bf16 is only available when pytorch-gpu'
        ' version >= 1.6 && bf16 supported by device')
    def test_optim_context_bf16(self):
        amp_optim_wrapper = AmpOptimWrapper(
            optimizer=self.optimizer, dtype='bfloat16')
        with amp_optim_wrapper.optim_context(self.model):
            x = torch.randn(1, 1, 1, 1).cuda()
            y = nn.Conv2d(1, 1, 1).cuda()(x)
            self.assertEqual(y.dtype, torch.bfloat16)
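
For reference, the final assertion mirrors plain PyTorch autocast behavior with bf16. A standalone sketch of the same dtype check, assuming optim_context wraps a torch.autocast region (an implementation detail this diff does not show):

import torch
import torch.nn as nn

# Under a bf16 autocast region, float32 conv inputs produce bfloat16
# outputs, which is what test_optim_context_bf16 asserts above.
conv = nn.Conv2d(1, 1, 1).cuda()
with torch.autocast('cuda', dtype=torch.bfloat16):
    y = conv(torch.randn(1, 1, 1, 1).cuda())
assert y.dtype == torch.bfloat16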