
Commit

Torch AMP + ShardedDDP (#667)
Summary:
Pull Request resolved: #667

Adding Torch AMP mixed precision support, working in conjunction with Classy Vision's existing task setup.

Reviewed By: mannatsingh

Differential Revision: D25400577

fbshipit-source-id: 687eefac5cb4ea24d89ac22a92492389d136f015
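
For context, a minimal, illustrative sketch of the PyTorch-native AMP pattern this change builds on, assuming TorchGradScaler in the diff below is an import alias for torch.cuda.amp.GradScaler; the model, optimizer, and data here are placeholders, not objects from this repository.

import torch

# Illustrative placeholders, not names from Classy Vision.
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # assumed equivalent of TorchGradScaler

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randint(0, 2, (4,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():  # forward pass runs in mixed precision
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)  # unscales gradients; skips the step if they overflowed
scaler.update()  # adjust the loss scale for the next iteration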
blefaudeux authored and facebook-github-bot committed Dec 11, 2020
1 parent ff37fea commit a17e50a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions classy_vision/tasks/classification_task.py
@@ -455,6 +455,10 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
                     "Apex AMP is required but Apex is not installed, cannot enable AMP"
                 )
 
+            # Set Torch AMP grad scaler, used to prevent gradient underflow
+            elif self.amp_type == AmpType.PYTORCH:
+                self.amp_grad_scaler = TorchGradScaler()
+
         logging.info(f"AMP enabled with args {amp_args}")
         return self

@@ -744,8 +748,6 @@ def prepare(self):
             self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                 self.base_model, self.optimizer.optimizer, **self.amp_args
             )
-        elif self.amp_type == AmpType.PYTORCH:
-            self.amp_grad_scaler = TorchGradScaler()
 
         if self.simulated_global_batchsize is not None:
             if self.simulated_global_batchsize % self.get_global_batchsize() != 0:
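
A hypothetical usage sketch of the path added above; the {"amp_type": "pytorch"} argument shape is an assumption inferred from the AmpType.PYTORCH check in set_amp_args, not something this diff spells out. The chained call relies on set_amp_args returning self, as shown in the first hunk.

from classy_vision.tasks import ClassificationTask

# Hypothetical: ask the task to use the Torch AMP grad scaler added in this
# commit. The exact amp_args schema is assumed, not shown in this diff.
task = ClassificationTask().set_amp_args({"amp_type": "pytorch"})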
