add tests on the creation of the optimizer group (facebookresearch#178)
Summary:
In preparation for the work on the optimizer group enhancement for FSDP:
- These tests will catch accidental regressions introduced by future refactoring.
- These tests add an important use case to our list of integration tests: fine-tuning.

Pull Request resolved: fairinternal/ssl_scaling#178

Reviewed By: prigoyal

Differential Revision: D30233695

Pulled By: QuentinDuval

fbshipit-source-id: ca67bc5f65328976298c3b09b17048dd4bfde934
QuentinDuval authored and facebook-github-bot committed Aug 12, 2021
1 parent b780610 commit 1f08d77
Showing 6 changed files with 380 additions and 15 deletions.
19 changes: 19 additions & 0 deletions configs/config/test/integration_test/models/finetune_regnet_fsdp.yaml
@@ -0,0 +1,19 @@
# @package _global_
config:
MODEL:
TRUNK:
NAME: regnet
REGNET:
name: anynet
depths: [2, 4, 11, 1]
widths: [224, 448, 1232, 3024]
group_widths: [112, 112, 112, 112]
bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
strides: [2, 2, 2, 2]
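    # The MLP head takes the last RegNet stage width (3024) as input and outputs 1000 class logits (ImageNet-1k).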
HEAD:
PARAMS: [
["mlp", {"dims": [3024, 1000]}],
]
SYNC_BN_CONFIG:
CONVERT_BN_TO_SYNC_BN: True
SYNC_BN_TYPE: pytorch
117 changes: 117 additions & 0 deletions configs/config/test/integration_test/quick_eval_finetune_in1k.yaml
@@ -0,0 +1,117 @@
# @package _global_
config:
VERBOSE: True
LOG_FREQUENCY: 100
TEST_ONLY: False
TEST_EVERY_NUM_EPOCH: 1
TEST_MODEL: True
SEED_VALUE: 1
MULTI_PROCESSING_METHOD: forkserver
HOOKS:
PERF_STATS:
MONITOR_PERF_STATS: True
DATA:
NUM_DATALOADER_WORKERS: 5
TRAIN:
DATA_SOURCES: [disk_folder]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_folder]
BATCHSIZE_PER_REPLICA: 32
TRANSFORMS:
- name: RandomResizedCrop
size: 224
- name: RandomHorizontalFlip
- name: ToTensor
- name: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
MMAP_MODE: True
COPY_TO_LOCAL_DISK: False
COPY_DESTINATION_DIR: /tmp/imagenet1k/
TEST:
DATA_SOURCES: [disk_folder]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_folder]
BATCHSIZE_PER_REPLICA: 32
TRANSFORMS:
- name: Resize
size: 256
- name: CenterCrop
size: 224
- name: ToTensor
- name: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
MMAP_MODE: True
COPY_TO_LOCAL_DISK: False
COPY_DESTINATION_DIR: /tmp/imagenet1k/
METERS:
name: accuracy_list_meter
accuracy_list_meter:
num_meters: 1
topk_values: [1, 5]
TRAINER:
TRAIN_STEP_NAME: standard_train_step
MODEL:
FEATURE_EVAL_SETTINGS:
EVAL_MODE_ON: True
EVAL_TRUNK_AND_HEAD: False
TRUNK:
NAME: resnet
RESNETS:
DEPTH: 50
HEAD:
PARAMS: [
["mlp", {"dims": [2048, 1000]}],
]
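    # PARAMS_FILE is a placeholder: the fine-tuning integration test overrides it with the checkpoint produced by its pre-training run.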
WEIGHTS_INIT:
PARAMS_FILE: "specify model weights"
STATE_DICT_KEY_NAME: classy_state_dict
APPEND_PREFIX: trunk.
LOSS:
name: cross_entropy_multiple_output_single_target
cross_entropy_multiple_output_single_target:
ignore_index: -1
OPTIMIZER:
name: sgd
weight_decay: 0.0001
momentum: 0.9
num_epochs: 10
nesterov: True
regularize_bn: False
regularize_bias: True
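    # With use_different_lr and use_different_wd enabled, the head parameters follow their own
    # schedule (param_schedulers.lr_head below) and weight decay, while the trunk follows
    # param_schedulers.lr and the top-level weight_decay.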
head_optimizer_params:
use_different_lr: True
use_different_wd: True
weight_decay: 0.000001
param_schedulers:
lr:
auto_lr_scaling:
auto_scale: true
base_value: 0.1
base_lr_batch_size: 256
name: cosine
start_value: 0.01
end_value: 0.0001
update_interval: epoch
lr_head:
auto_lr_scaling:
auto_scale: true
base_value: 0.1
base_lr_batch_size: 256
name: cosine
start_value: 0.1
end_value: 0.001
update_interval: epoch
DISTRIBUTED:
BACKEND: nccl
NUM_NODES: 1
NUM_PROC_PER_NODE: 8
INIT_METHOD: tcp
RUN_ID: auto
MACHINE:
DEVICE: gpu
CHECKPOINT:
DIR: "."
AUTO_RESUME: True
CHECKPOINT_FREQUENCY: 1
1 change: 1 addition & 0 deletions dev/run_quick_tests.sh
@@ -15,6 +15,7 @@ SRC_DIR=$(dirname "${SRC_DIR}")
TEST_LIST=(
"test_extract_cluster.py"
"test_extract_features.py"
"test_finetuning.py"
"test_larc_fsdp.py"
"test_layer_memory_tracking.py"
"test_losses_gpu.py"
211 changes: 211 additions & 0 deletions tests/test_finetuning.py
@@ -0,0 +1,211 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

from classy_vision.optim import build_optimizer_schedulers
from hydra.experimental import compose, initialize_config_module
from vissl.models import build_model
from vissl.optimizers import get_optimizer_param_groups
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.test_utils import (
gpu_test,
in_temporary_directory,
init_distributed_on_file,
run_integration_test,
with_temp_files,
)


class TestFineTuning(unittest.TestCase):
@staticmethod
def _create_pretraining_config(num_gpu: int = 2, with_fsdp: bool = False):
with initialize_config_module(config_module="vissl.config"):
cfg = compose(
"defaults",
overrides=[
"config=test/integration_test/quick_swav",
"+config/test/integration_test/models=swav_regnet_fsdp",
"config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
"config.DATA.TRAIN.DATA_LIMIT=40",
"config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
"config.SEED_VALUE=0",
"config.LOSS.swav_loss.epsilon=0.03",
f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}",
"config.LOG_FREQUENCY=1",
"config.OPTIMIZER.construct_single_param_group_only=True",
],
)

args, config = convert_to_attrdict(cfg)
if with_fsdp:
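            # FSDP runs swap in the FSDP-aware trunk and head implementations and the FSDP training task.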
config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head_fsdp"
config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
else:
config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head"
return config

@staticmethod
def _create_finetuning_config(
checkpoint_path: str,
num_gpu: int = 2,
regularize_bias: bool = False,
construct_single_param_group_only: bool = False,
with_fsdp: bool = False,
):
with initialize_config_module(config_module="vissl.config"):
cfg = compose(
"defaults",
overrides=[
"config=test/integration_test/quick_eval_finetune_in1k",
"+config/test/integration_test/models=finetune_regnet_fsdp",
f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={checkpoint_path}",
"config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
"config.DATA.TRAIN.LABEL_SOURCES=[synthetic]",
"config.DATA.TEST.DATA_SOURCES=[synthetic]",
"config.DATA.TEST.LABEL_SOURCES=[synthetic]",
"config.DATA.TRAIN.DATA_LIMIT=40",
"config.DATA.TEST.DATA_LIMIT=20",
"config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
"config.DATA.TEST.BATCHSIZE_PER_REPLICA=2",
"config.SEED_VALUE=0",
f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}",
"config.LOG_FREQUENCY=1",
"config.OPTIMIZER.num_epochs=2",
"config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.base_value=0.01",
"config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.base_lr_batch_size=2",
"config.OPTIMIZER.param_schedulers.lr_head.auto_lr_scaling.base_value=0.1",
"config.OPTIMIZER.param_schedulers.lr_head.auto_lr_scaling.base_lr_batch_size=2",
f"config.OPTIMIZER.regularize_bias={regularize_bias}",
f"config.OPTIMIZER.construct_single_param_group_only={construct_single_param_group_only}",
],
)
args, config = convert_to_attrdict(cfg)
if with_fsdp:
config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
config["MODEL"]["HEAD"]["PARAMS"][0][0] = "mlp_fsdp"
config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
else:
config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
config["MODEL"]["HEAD"]["PARAMS"][0][0] = "mlp"
return config

@gpu_test(gpu_count=1)
def test_get_optimizer_param_groups(self):
finetune_config = self._create_finetuning_config(
checkpoint_path="",
construct_single_param_group_only=False,
regularize_bias=False,
)
optimizer_schedulers = build_optimizer_schedulers(finetune_config["OPTIMIZER"])
base_model = build_model(finetune_config["MODEL"], finetune_config["OPTIMIZER"])
param_groups = get_optimizer_param_groups(
model=base_model,
model_config=finetune_config["MODEL"],
optimizer_config=finetune_config["OPTIMIZER"],
optimizer_schedulers=optimizer_schedulers,
)
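
        # The expected values below are consistent with linear LR auto-scaling of the overrides
        # above: effective batch size = 4 per replica x 2 processes = 8, base_lr_batch_size = 2,
        # so scheduler values scale by 8 / 2 = 4 (trunk: 0.01 -> 0.04, head: 0.1 -> 0.4).
        # Weight decay: 1e-4 for regular trunk weights, 0 for BN and biases (regularize_bn and
        # regularize_bias are False), 1e-6 for head weights (head_optimizer_params), 0 for head biases.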

expected_param_groups = [
{
"params_count": 95,
"params_numel": 80_419_552,
"start_lr": 0.04,
"end_lr": 0.0004,
"weight_decay": 0.0001,
},
{
"params_count": 154,
"params_numel": 145_588,
"start_lr": 0.04,
"end_lr": 0.0004,
"weight_decay": 0.0,
},
{
# Params for linear layer matrix
"params_count": 1,
"params_numel": 3024 * 1000,
"start_lr": 0.4,
"end_lr": 0.004,
"weight_decay": 1e-6,
},
{
# Params for linear layer biases
"params_count": 1,
"params_numel": 1000,
"start_lr": 0.4,
"end_lr": 0.004,
"weight_decay": 0.0,
},
]

for i, param_group in enumerate(param_groups):
numel = sum(p.numel() for p in param_group["params"])
self.assertEqual(set(param_group.keys()), {"params", "lr", "weight_decay"})
self.assertEqual(
len(param_group["params"]), expected_param_groups[i]["params_count"]
)
self.assertEqual(numel, expected_param_groups[i]["params_numel"])
self.assertEqual(
param_group["lr"]._start_value, expected_param_groups[i]["start_lr"]
)
self.assertEqual(
param_group["lr"]._end_value, expected_param_groups[i]["end_lr"]
)
self.assertEqual(
param_group["weight_decay"], expected_param_groups[i]["weight_decay"]
)

@gpu_test(gpu_count=1)
def test_get_optimizer_param_groups_fsdp_single_group(self):
with with_temp_files(count=1) as sync_file:
init_distributed_on_file(world_size=1, gpu_id=0, sync_file=sync_file)

finetune_config = self._create_finetuning_config(
checkpoint_path="",
construct_single_param_group_only=True,
regularize_bias=False,
with_fsdp=True,
)
optimizer_schedulers = build_optimizer_schedulers(
finetune_config["OPTIMIZER"]
)
base_model = build_model(
finetune_config["MODEL"], finetune_config["OPTIMIZER"]
)
param_groups = get_optimizer_param_groups(
model=base_model,
model_config=finetune_config["MODEL"],
optimizer_config=finetune_config["OPTIMIZER"],
optimizer_schedulers=optimizer_schedulers,
)
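
            # With construct_single_param_group_only=True, a single group holds every parameter:
            # 83,590,140 = 80,419,552 + 145,588 + 3,024,000 + 1,000 (the four groups checked above).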

expected_param_groups = [{"param_numel": 83590140}]

for i, param_group in enumerate(param_groups):
numel = sum(p.numel() for p in param_group["params"])
self.assertEqual(expected_param_groups[i]["param_numel"], numel)

@gpu_test(gpu_count=2)
def test_fine_tuning_end_to_end(self):
with in_temporary_directory() as pretrain_dir:

            # Run a pre-training to have some weights to begin with
pretrain_config = self._create_pretraining_config()
run_integration_test(pretrain_config)
checkpoint_path = os.path.join(pretrain_dir, "checkpoint.torch")

            # Create a separate directory in which to run the fine-tuning
with in_temporary_directory():
finetune_config = self._create_finetuning_config(
checkpoint_path,
construct_single_param_group_only=False,
regularize_bias=False,
)
result = run_integration_test(finetune_config)
accuracies = result.get_accuracies(from_metrics_file=True)
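                # 2 fine-tuning epochs, evaluated every epoch with top-1/top-5 meters, should yield 4 accuracy entries.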
self.assertEqual(4, len(accuracies))
2 changes: 1 addition & 1 deletion vissl/hooks/log_hooks.py
@@ -229,7 +229,7 @@ def on_update(self, task: "tasks.ClassyTask") -> None:

eta_secs = avg_time * (task.max_iteration - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_secs)))
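        # With several optimizer param groups (e.g. separate head and trunk groups),
        # options_view.lr can be a list of per-group values instead of a single scalar or set.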
if isinstance(task.optimizer.options_view.lr, set):
if isinstance(task.optimizer.options_view.lr, (set, list)):
lr_val = list(task.optimizer.options_view.lr)
else:
lr_val = round(task.optimizer.options_view.lr, 5)