add tests on the creation of the optimizer group (facebookresearch#178)
Summary: In preparation for the work on the optimizer group enhancement for FSDP:
- These tests will catch accidental issues that might arise when refactoring is needed.
- These tests add an important use case to our list of integration tests: fine-tuning.

Pull Request resolved: fairinternal/ssl_scaling#178
Reviewed By: prigoyal
Differential Revision: D30233695
Pulled By: QuentinDuval
fbshipit-source-id: ca67bc5f65328976298c3b09b17048dd4bfde934
1 parent b780610, commit 1f08d77
Showing 6 changed files with 380 additions and 15 deletions.
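For context on what these tests exercise: optimizer parameter groups are the standard PyTorch mechanism for giving different subsets of a model's parameters their own hyper-parameters, which is how fine-tuning applies a smaller learning rate and a different weight decay to the pretrained trunk than to the new head. The sketch below illustrates only that underlying mechanism with a toy model; it is not VISSL's `get_optimizer_param_groups` implementation, which is what the new tests cover.

```python
import torch
import torch.nn as nn

# Toy stand-ins for a pretrained trunk and a freshly initialized head
# (illustration only; the tests use a RegNet trunk and an MLP head).
trunk = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
head = nn.Linear(8, 10)

# One param group per subset, each with its own lr / weight decay.
param_groups = [
    {"params": list(trunk.parameters()), "lr": 0.01, "weight_decay": 1e-4},
    {"params": list(head.parameters()), "lr": 0.1, "weight_decay": 1e-6},
]
optimizer = torch.optim.SGD(param_groups, lr=0.01, momentum=0.9, nesterov=True)

# The optimizer keeps the per-group settings; this per-group structure
# (parameter counts, lr, weight_decay) is what the tests below assert on.
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])
```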
configs/config/test/integration_test/models/finetune_regnet_fsdp.yaml (19 additions, 0 deletions)
@@ -0,0 +1,19 @@
# @package _global_
config:
  MODEL:
    TRUNK:
      NAME: regnet
      REGNET:
        name: anynet
        depths: [2, 4, 11, 1]
        widths: [224, 448, 1232, 3024]
        group_widths: [112, 112, 112, 112]
        bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
        strides: [2, 2, 2, 2]
    HEAD:
      PARAMS: [
        ["mlp", {"dims": [3024, 1000]}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: pytorch
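A quick note on the head configuration above (plain arithmetic, not part of the commit): `["mlp", {"dims": [3024, 1000]}]` is a single linear layer mapping the RegNet trunk's final width of 3024 to 1000 classes, which is where the weight and bias sizes asserted in the test file further below come from.

```python
# Size of the Linear(3024, 1000) fine-tuning head configured above.
in_dim, out_dim = 3024, 1000
weight_numel = in_dim * out_dim  # 3_024_000 -> "params_numel": 3024 * 1000 in the test
bias_numel = out_dim             # 1_000     -> the separate bias parameter group
print(weight_numel, bias_numel)
```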
configs/config/test/integration_test/quick_eval_finetune_in1k.yaml (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
# @package _global_
config:
  VERBOSE: True
  LOG_FREQUENCY: 100
  TEST_ONLY: False
  TEST_EVERY_NUM_EPOCH: 1
  TEST_MODEL: True
  SEED_VALUE: 1
  MULTI_PROCESSING_METHOD: forkserver
  HOOKS:
    PERF_STATS:
      MONITOR_PERF_STATS: True
  DATA:
    NUM_DATALOADER_WORKERS: 5
    TRAIN:
      DATA_SOURCES: [disk_folder]
      LABEL_SOURCES: [disk_folder]
      DATASET_NAMES: [imagenet1k_folder]
      BATCHSIZE_PER_REPLICA: 32
      TRANSFORMS:
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenet1k/
    TEST:
      DATA_SOURCES: [disk_folder]
      LABEL_SOURCES: [disk_folder]
      DATASET_NAMES: [imagenet1k_folder]
      BATCHSIZE_PER_REPLICA: 32
      TRANSFORMS:
        - name: Resize
          size: 256
        - name: CenterCrop
          size: 224
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenet1k/
  METERS:
    name: accuracy_list_meter
    accuracy_list_meter:
      num_meters: 1
      topk_values: [1, 5]
  TRAINER:
    TRAIN_STEP_NAME: standard_train_step
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      EVAL_TRUNK_AND_HEAD: False
    TRUNK:
      NAME: resnet
      RESNETS:
        DEPTH: 50
    HEAD:
      PARAMS: [
        ["mlp", {"dims": [2048, 1000]}],
      ]
    WEIGHTS_INIT:
      PARAMS_FILE: "specify model weights"
      STATE_DICT_KEY_NAME: classy_state_dict
      APPEND_PREFIX: trunk.
  LOSS:
    name: cross_entropy_multiple_output_single_target
    cross_entropy_multiple_output_single_target:
      ignore_index: -1
  OPTIMIZER:
    name: sgd
    weight_decay: 0.0001
    momentum: 0.9
    num_epochs: 10
    nesterov: True
    regularize_bn: False
    regularize_bias: True
    head_optimizer_params:
      use_different_lr: True
      use_different_wd: True
      weight_decay: 0.000001
    param_schedulers:
      lr:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.1
          base_lr_batch_size: 256
        name: cosine
        start_value: 0.01
        end_value: 0.0001
        update_interval: epoch
      lr_head:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.1
          base_lr_batch_size: 256
        name: cosine
        start_value: 0.1
        end_value: 0.001
        update_interval: epoch
  DISTRIBUTED:
    BACKEND: nccl
    NUM_NODES: 1
    NUM_PROC_PER_NODE: 8
    INIT_METHOD: tcp
    RUN_ID: auto
  MACHINE:
    DEVICE: gpu
  CHECKPOINT:
    DIR: "."
    AUTO_RESUME: True
    CHECKPOINT_FREQUENCY: 1
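The `auto_lr_scaling` blocks above are what produce the learning-rate values asserted in the test file below. Assuming the usual linear scaling rule (the scheduler's start and end values are multiplied by the global batch size divided by `base_lr_batch_size`), the test's overrides of 2 GPUs, `BATCHSIZE_PER_REPLICA=4` and `base_lr_batch_size=2` give a factor of 4; a rough sketch of that arithmetic:

```python
# Worked example of the scaled LR values the test expects, assuming the
# linear auto-LR-scaling rule: scale = global_batch_size / base_lr_batch_size.
num_gpus, batch_per_replica, base_lr_batch_size = 2, 4, 2
scale = (num_gpus * batch_per_replica) / base_lr_batch_size  # = 4.0

trunk_start, trunk_end = 0.01 * scale, 0.0001 * scale  # 0.04, 0.0004
head_start, head_end = 0.1 * scale, 0.001 * scale      # 0.4, 0.004
print(trunk_start, trunk_end, head_start, head_end)
```

The head schedule being ten times the trunk schedule reflects `head_optimizer_params.use_different_lr: True` together with the separate `lr_head` scheduler.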
@@ -0,0 +1,211 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

from classy_vision.optim import build_optimizer_schedulers
from hydra.experimental import compose, initialize_config_module
from vissl.models import build_model
from vissl.optimizers import get_optimizer_param_groups
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.test_utils import (
    gpu_test,
    in_temporary_directory,
    init_distributed_on_file,
    run_integration_test,
    with_temp_files,
)


class TestFineTuning(unittest.TestCase):
    @staticmethod
    def _create_pretraining_config(num_gpu: int = 2, with_fsdp: bool = False):
        with initialize_config_module(config_module="vissl.config"):
            cfg = compose(
                "defaults",
                overrides=[
                    "config=test/integration_test/quick_swav",
                    "+config/test/integration_test/models=swav_regnet_fsdp",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.DATA_LIMIT=40",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.SEED_VALUE=0",
                    "config.LOSS.swav_loss.epsilon=0.03",
                    f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}",
                    "config.LOG_FREQUENCY=1",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                ],
            )

        args, config = convert_to_attrdict(cfg)
        if with_fsdp:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head_fsdp"
            config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
        else:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head"
        return config

    @staticmethod
    def _create_finetuning_config(
        checkpoint_path: str,
        num_gpu: int = 2,
        regularize_bias: bool = False,
        construct_single_param_group_only: bool = False,
        with_fsdp: bool = False,
    ):
        with initialize_config_module(config_module="vissl.config"):
            cfg = compose(
                "defaults",
                overrides=[
                    "config=test/integration_test/quick_eval_finetune_in1k",
                    "+config/test/integration_test/models=finetune_regnet_fsdp",
                    f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={checkpoint_path}",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TEST.DATA_SOURCES=[synthetic]",
                    "config.DATA.TEST.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.DATA_LIMIT=40",
                    "config.DATA.TEST.DATA_LIMIT=20",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.DATA.TEST.BATCHSIZE_PER_REPLICA=2",
                    "config.SEED_VALUE=0",
                    f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}",
                    "config.LOG_FREQUENCY=1",
                    "config.OPTIMIZER.num_epochs=2",
                    "config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.base_value=0.01",
                    "config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.base_lr_batch_size=2",
                    "config.OPTIMIZER.param_schedulers.lr_head.auto_lr_scaling.base_value=0.1",
                    "config.OPTIMIZER.param_schedulers.lr_head.auto_lr_scaling.base_lr_batch_size=2",
                    f"config.OPTIMIZER.regularize_bias={regularize_bias}",
                    f"config.OPTIMIZER.construct_single_param_group_only={construct_single_param_group_only}",
                ],
            )
        args, config = convert_to_attrdict(cfg)
        if with_fsdp:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "mlp_fsdp"
            config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
        else:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "mlp"
        return config

    @gpu_test(gpu_count=1)
    def test_get_optimizer_param_groups(self):
        finetune_config = self._create_finetuning_config(
            checkpoint_path="",
            construct_single_param_group_only=False,
            regularize_bias=False,
        )
        optimizer_schedulers = build_optimizer_schedulers(finetune_config["OPTIMIZER"])
        base_model = build_model(finetune_config["MODEL"], finetune_config["OPTIMIZER"])
        param_groups = get_optimizer_param_groups(
            model=base_model,
            model_config=finetune_config["MODEL"],
            optimizer_config=finetune_config["OPTIMIZER"],
            optimizer_schedulers=optimizer_schedulers,
        )

        expected_param_groups = [
            {
                "params_count": 95,
                "params_numel": 80_419_552,
                "start_lr": 0.04,
                "end_lr": 0.0004,
                "weight_decay": 0.0001,
            },
            {
                "params_count": 154,
                "params_numel": 145_588,
                "start_lr": 0.04,
                "end_lr": 0.0004,
                "weight_decay": 0.0,
            },
            {
                # Params for linear layer matrix
                "params_count": 1,
                "params_numel": 3024 * 1000,
                "start_lr": 0.4,
                "end_lr": 0.004,
                "weight_decay": 1e-6,
            },
            {
                # Params for linear layer biases
                "params_count": 1,
                "params_numel": 1000,
                "start_lr": 0.4,
                "end_lr": 0.004,
                "weight_decay": 0.0,
            },
        ]

        for i, param_group in enumerate(param_groups):
            numel = sum(p.numel() for p in param_group["params"])
            self.assertEqual(set(param_group.keys()), {"params", "lr", "weight_decay"})
            self.assertEqual(
                len(param_group["params"]), expected_param_groups[i]["params_count"]
            )
            self.assertEqual(numel, expected_param_groups[i]["params_numel"])
            self.assertEqual(
                param_group["lr"]._start_value, expected_param_groups[i]["start_lr"]
            )
            self.assertEqual(
                param_group["lr"]._end_value, expected_param_groups[i]["end_lr"]
            )
            self.assertEqual(
                param_group["weight_decay"], expected_param_groups[i]["weight_decay"]
            )

    @gpu_test(gpu_count=1)
    def test_get_optimizer_param_groups_fsdp_single_group(self):
        with with_temp_files(count=1) as sync_file:
            init_distributed_on_file(world_size=1, gpu_id=0, sync_file=sync_file)

            finetune_config = self._create_finetuning_config(
                checkpoint_path="",
                construct_single_param_group_only=True,
                regularize_bias=False,
                with_fsdp=True,
            )
            optimizer_schedulers = build_optimizer_schedulers(
                finetune_config["OPTIMIZER"]
            )
            base_model = build_model(
                finetune_config["MODEL"], finetune_config["OPTIMIZER"]
            )
            param_groups = get_optimizer_param_groups(
                model=base_model,
                model_config=finetune_config["MODEL"],
                optimizer_config=finetune_config["OPTIMIZER"],
                optimizer_schedulers=optimizer_schedulers,
            )

            expected_param_groups = [{"param_numel": 83590140}]

            for i, param_group in enumerate(param_groups):
                numel = sum(p.numel() for p in param_group["params"])
                self.assertEqual(expected_param_groups[i]["param_numel"], numel)

    @gpu_test(gpu_count=2)
    def test_fine_tuning_end_to_end(self):
        with in_temporary_directory() as pretrain_dir:

            # Run a pre-training to have some weights to begin with
            pretrain_config = self._create_pretraining_config()
            run_integration_test(pretrain_config)
            checkpoint_path = os.path.join(pretrain_dir, "checkpoint.torch")

            # Create a separate directory in which to run the fine-tuning
            with in_temporary_directory():
                finetune_config = self._create_finetuning_config(
                    checkpoint_path,
                    construct_single_param_group_only=False,
                    regularize_bias=False,
                )
                result = run_integration_test(finetune_config)
                accuracies = result.get_accuracies(from_metrics_file=True)
                self.assertEqual(4, len(accuracies))
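As a cross-check on the expected values in the two parameter-group tests above (plain arithmetic, not part of the commit), the single FSDP parameter group should contain exactly the parameters that the four fine-tuning groups contain together:

```python
# Sum of the four expected group sizes from test_get_optimizer_param_groups ...
group_numels = [80_419_552, 145_588, 3024 * 1000, 1000]
# ... equals the single-group total expected in the FSDP variant of the test.
assert sum(group_numels) == 83_590_140
```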