-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MobileNetV3 Architecture in TorchVision #3182
Conversation
…sses and methods.
# Conflicts: # torchvision/models/mobilenet.py # torchvision/models/quantization/mobilenet.py
Codecov Report
@@ Coverage Diff @@
## master #3182 +/- ##
==========================================
+ Coverage 73.49% 73.79% +0.30%
==========================================
Files 101 102 +1
Lines 9235 9354 +119
Branches 1477 1490 +13
==========================================
+ Hits 6787 6903 +116
- Misses 1991 1993 +2
- Partials 457 458 +1
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice PR @datumbox !
Few nits...
The failing builds seem unrelated. See issue #3183 |
385e077
to
585374c
Compare
6ba7f15
to
d912443
Compare
d912443
to
5198385
Compare
b415a70
to
e4d130f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a few comments to assist review.
# TODO: add pretrained | ||
model_urls = { | ||
"mobilenet_v3_large": None, | ||
"mobilenet_v3_small": None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pending S3 bucket access and training finalization.
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) | ||
elif opt_name == 'rmsprop': | ||
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, | ||
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These hardcoded params are crucial for convergence =. They can be turned into args.
403396c
to
c0a13a2
Compare
I will merge this PR on a separate branch to continue with the changes necessary for Object Detection. I'll send a new PR on master once all changes are final. |
* Add MobileNetV3 Architecture in TorchVision (#3182) * Adding implementation of network architecture * Adding rmsprop support on the train.py * Adding auto-augment and random-erase in the training scripts. * Adding support for reduced tail on MobileNetV3. * Tagging blocks with comments. * Adding documentation, pre-trained model URL and a minor refactoring. * Handling better untrained supported models.
Summary: * Add MobileNetV3 Architecture in TorchVision (#3182) * Adding implementation of network architecture * Adding rmsprop support on the train.py * Adding auto-augment and random-erase in the training scripts. * Adding support for reduced tail on MobileNetV3. * Tagging blocks with comments. * Adding documentation, pre-trained model URL and a minor refactoring. * Handling better untrained supported models. Reviewed By: datumbox Differential Revision: D25954557 fbshipit-source-id: f7d72a81a2ec92cbbbf3bd86c68ae0a426626cc7
Partially fixes #1676
Depends and cherrypicks commits from #3177
The current temporary pre-trained model was trained:
Submitted batch job 34241491
Then we took the 3 last checkpoints (epochs 549, 528, 408) that improved the Acc@1 and averaged their parameters using the following script:
Validated with:
Submitted batch job 34643680
Accuracy metrics:
Acc@1 74.042 Acc@5 91.340
Speed Benchmark:
0.0411 sec per image on CPU