Commit
fix grad scaler for MPS, which doesn't support FP64
- remove fp64 intermediate cast
masc-it committed Dec 31, 2024
1 parent 01034e9 commit 03e39b0
Showing 1 changed file with 15 additions and 13 deletions.
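
The motivation for the fix: the MPS backend has no float64 support, so the FP64 intermediate used in unscale_() fails on Apple-silicon devices. A minimal sketch (not part of the commit) to reproduce the limitation, assuming a torch build with MPS enabled; the exact exception type and message may vary between versions:

import torch

# float64 is not implemented on the MPS backend, so .double() on an MPS tensor
# is expected to raise; this is why the fp64 intermediate cast had to go.
if torch.backends.mps.is_available():
    t = torch.ones((), device="mps")
    try:
        t.double()
    except Exception as err:
        print("FP64 unsupported on MPS:", err)
else:
    print("MPS not available on this machine; nothing to demonstrate.")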
28 changes: 15 additions & 13 deletions torch/amp/grad_scaler.py
@@ -171,20 +171,16 @@ def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
             )
 
     @overload
-    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
-        ...
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
 
     @overload
-    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        ...
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]: ...
 
     @overload
-    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
-        ...
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: ...
 
     @overload
-    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
-        ...
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...
 
     def scale(
         self,
@@ -331,8 +327,12 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
             raise RuntimeError("unscale_() is being called after step().")
 
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        # (except for MPS, which does not support FP64)
         assert self._scale is not None
-        inv_scale = self._scale.double().reciprocal().float()
+        if self._scale.device.type == "mps":
+            inv_scale = self._scale.reciprocal()
+        else:
+            inv_scale = self._scale.double().reciprocal().float()
         found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
 
         optimizer_state["found_inf_per_device"] = self._unscale_grads_(
@@ -496,8 +496,10 @@ def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
             if isinstance(new_scale, float):
                 self._scale.fill_(new_scale)
             else:
-                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
+                reason = (
+                    "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
                 torch.FloatTensor with requires_grad=False."
+                )
                 assert new_scale.device.type == self._device, reason
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
@@ -675,9 +677,9 @@ def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
         found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
 
-        self._per_optimizer_states[id(optimizer)][
-            "found_inf_per_device"
-        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
+        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = (
+            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
+        )
 
         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

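The patched reciprocal logic can also be exercised in isolation. A small illustrative sketch (not the library API), assuming the scale tensor lives on MPS when available and on CPU otherwise; 1/65536 is exactly representable in FP32, so both branches yield the same value here:

import torch

# Mirror of the device guard added to unscale_(): skip the FP64 round-trip on MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
scale = torch.full((), 65536.0, dtype=torch.float32, device=device)

if scale.device.type == "mps":
    inv_scale = scale.reciprocal()                   # stay in FP32; MPS has no FP64
else:
    inv_scale = scale.double().reciprocal().float()  # FP64 round-trip for extra precision

print(inv_scale.item())  # 1.52587890625e-05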
