[Feature] enable bf16 in AmpOptimWrapper #960
Motivation
Enable `torch.bfloat16` in `AmpOptimWrapper` and in config files.

Modification

New optional argument `dtype` in `AmpOptimWrapper`; it can be a `str` or a `torch.dtype`.
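A minimal config sketch of the new argument (the optimizer settings here are illustrative, not taken from the PR):

```python
# Config-style usage of AmpOptimWrapper with the new `dtype` argument.
# `dtype` accepts a string such as 'bfloat16' (or a torch.dtype object).
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='bfloat16',
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9))
```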
BC-breaking (Optional)
No
Use cases (Optional)
Test results
mmcls::resnet50_8xb256-rsb-a3-100e_in1k
Hint: Amp with `dtype=torch.bfloat16` performs poorly on convolutions because it does not use cuDNN by default. Enable cuDNN bfloat16 convolutions with the environment variable `TORCH_CUDNN_V8_API_ENABLED=1`.
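One way to set the variable for a run (the launch command below is illustrative, not from the PR):

```shell
# Opt in to the cuDNN v8 API so bfloat16 convolutions use cuDNN kernels.
export TORCH_CUDNN_V8_API_ENABLED=1
# ...then launch training as usual, e.g.:
# python tools/train.py <your_config>
```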
mmdet::retinanet_r50_fpn_1x_coco
Failed due to `torch.bfloat16` not being supported by `F.interpolate`.
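A possible workaround sketch (not part of this PR, and the helper name is an assumption): upcast to float32 around the unsupported op, then cast the result back.

```python
import torch
import torch.nn.functional as F

def interpolate_bf16_safe(x, **kwargs):
    # F.interpolate lacked bfloat16 support in the failing run above,
    # so upcast to float32, interpolate, then cast the result back.
    if x.dtype is torch.bfloat16:
        return F.interpolate(x.float(), **kwargs).to(torch.bfloat16)
    return F.interpolate(x, **kwargs)

x = torch.randn(1, 3, 8, 8).to(torch.bfloat16)
y = interpolate_bf16_safe(x, scale_factor=2, mode='bilinear')
```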
Checklist