Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MobileNetV3 Architecture in TorchVision #3182

Merged
merged 26 commits into from
Jan 5, 2021

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Dec 16, 2020

Partially fixes #1676

Depends and cherrypicks commits from #3177


The current temporary pre-trained model was trained:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--model mobilenet_v3_large --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\ 
--wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2

Submitted batch job 34241491

Then we took the 3 last checkpoints (epochs 549, 528, 408) that improved the Acc@1 and averaged their parameters using the following script:

# from https://github.com/pytorch/fairseq/blob/master/scripts/average_checkpoints.py
import collections
import torch


def average_checkpoints(inputs):
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state["model"]
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(f, params_keys, model_params_keys)
            )
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                params_dict[k] += p
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state


def avg(epochs, filename):
    paths = ["model_{}.pth".format(i) for i in epochs]
    weights = average_checkpoints(paths)
    torch.save(weights, filename.format(len(epochs)))

avg([549, 528, 408], "model_best{}avg.pth")

Validated with:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
 --model mobilenet_v3_large --test-only --pretrained

Submitted batch job 34643680

Accuracy metrics:
Acc@1 74.042 Acc@5 91.340

Speed Benchmark: 0.0411 sec per image on CPU

@codecov
Copy link

codecov bot commented Dec 16, 2020

Codecov Report

Merging #3182 (5030435) into master (4cbe714) will increase coverage by 0.30%.
The diff coverage is 94.95%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3182      +/-   ##
==========================================
+ Coverage   73.49%   73.79%   +0.30%     
==========================================
  Files         101      102       +1     
  Lines        9235     9354     +119     
  Branches     1477     1490      +13     
==========================================
+ Hits         6787     6903     +116     
- Misses       1991     1993       +2     
- Partials      457      458       +1     
Impacted Files Coverage Δ
torchvision/models/mobilenetv3.py 94.91% <94.91%> (ø)
torchvision/models/mobilenet.py 100.00% <100.00%> (ø)
torchvision/models/mobilenetv2.py 86.51% <0.00%> (+3.37%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4cbe714...9a758a8. Read the comment docs.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice PR @datumbox !
Few nits...

torchvision/models/mobilenetv3.py Outdated Show resolved Hide resolved
torchvision/models/mobilenetv3.py Show resolved Hide resolved
@datumbox
Copy link
Contributor Author

The failing builds seem unrelated. See issue #3183

@datumbox datumbox force-pushed the models/mobilenetv3 branch 3 times, most recently from 385e077 to 585374c Compare December 20, 2020 20:26
@datumbox datumbox force-pushed the models/mobilenetv3 branch 2 times, most recently from 6ba7f15 to d912443 Compare December 30, 2020 20:16
@datumbox datumbox force-pushed the models/mobilenetv3 branch from d912443 to 5198385 Compare January 1, 2021 12:08
@datumbox datumbox force-pushed the models/mobilenetv3 branch from b415a70 to e4d130f Compare January 3, 2021 20:48
@datumbox datumbox mentioned this pull request Jan 5, 2021
13 tasks
Copy link
Contributor Author

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a few comments to assist review.

# TODO: add pretrained
model_urls = {
"mobilenet_v3_large": None,
"mobilenet_v3_small": None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pending S3 bucket access and training finalization.

model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These hardcoded params are crucial for convergence =. They can be turned into args.

@datumbox datumbox force-pushed the models/mobilenetv3 branch from 403396c to c0a13a2 Compare January 5, 2021 14:57
@datumbox datumbox changed the base branch from master to mobilenetv3 January 5, 2021 18:31
@datumbox datumbox changed the title [WIP] Add MobileNetV3 in TorchVision Add MobileNetV3 Architecture in TorchVision Jan 5, 2021
@datumbox
Copy link
Contributor Author

datumbox commented Jan 5, 2021

I will merge this PR on a separate branch to continue with the changes necessary for Object Detection. I'll send a new PR on master once all changes are final.

@datumbox datumbox merged commit aea1191 into pytorch:mobilenetv3 Jan 5, 2021
@datumbox datumbox deleted the models/mobilenetv3 branch January 5, 2021 19:18
datumbox added a commit that referenced this pull request Jan 14, 2021
* Add MobileNetV3 Architecture in TorchVision (#3182)

* Adding implementation of network architecture

* Adding rmsprop support on the train.py

* Adding auto-augment and random-erase in the training scripts.

* Adding support for reduced tail on MobileNetV3.

* Tagging blocks with comments.

* Adding documentation, pre-trained model URL and a minor refactoring.

* Handling better untrained supported models.
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2021
Summary:
* Add MobileNetV3 Architecture in TorchVision (#3182)

* Adding implementation of network architecture

* Adding rmsprop support on the train.py

* Adding auto-augment and random-erase in the training scripts.

* Adding support for reduced tail on MobileNetV3.

* Tagging blocks with comments.

* Adding documentation, pre-trained model URL and a minor refactoring.

* Handling better untrained supported models.

Reviewed By: datumbox

Differential Revision: D25954557

fbshipit-source-id: f7d72a81a2ec92cbbbf3bd86c68ae0a426626cc7
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Add MobileNet v3 to torchvision
4 participants