Skip to content

Commit

Permalink
state checkpointing (facebookresearch#122)
Browse files Browse the repository at this point in the history
Summary:
Linked to issue: fairinternal/ssl_scaling#87

This PR contains:

- added integration tests for the state checkpointing
- a fix in `log_hooks.py` for a bug that made a restart run one extra iteration
- a first factorisation of test utilities with an API for integration tests

Pull Request resolved: fairinternal/ssl_scaling#122

Reviewed By: prigoyal

Differential Revision: D28011417

Pulled By: QuentinDuval

fbshipit-source-id: b16abc32e1dbfc4647ab479559c7beadf7dec89f
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Apr 29, 2021
1 parent 84be0b1 commit 260869c
Show file tree
Hide file tree
Showing 9 changed files with 375 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# @package _global_
config:
MODEL:
TRUNK:
NAME: regnet # This could be overridden from the command line to be regnet_fsdp
REGNET:
# The following is the same as the model "regnet_y_16gf" from ClassyVision/classy_vision/models/regnet.py
depth: 18
w_0: 200
w_a: 106.23
w_m: 2.48
group_width: 112
HEAD:
PARAMS: [
["eval_mlp", {"in_channels": 3024, "dims": [3024, 10]}],
]
1 change: 1 addition & 0 deletions dev/run_quick_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SRC_DIR=$(dirname "${SRC_DIR}")
TEST_LIST=(
"test_regnet_fsdp.py"
"test_regnet_fsdp_integration.py"
"test_state_checkpointing.py"
)

echo "========================================================================"
Expand Down
54 changes: 12 additions & 42 deletions tests/test_regnet_fsdp_integration.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import re
import shutil
import tempfile
import unittest
from contextlib import contextmanager

import torch
from hydra.experimental import compose, initialize_config_module
from vissl.hooks import default_hook_generator
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.test_utils import (
gpu_test,
in_temporary_directory,
run_integration_test,
)


class TestRegnetFSDPIntegration(unittest.TestCase):
Expand Down Expand Up @@ -61,52 +59,24 @@ def _create_pretraining_config(
config.MODEL.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING = (
with_activation_checkpointing
)
return args, config

@staticmethod
@contextmanager
def _in_temporary_directory():
    """
    Run the body of the ``with`` statement inside a fresh temporary directory.

    A temporary directory is created and made the current working directory,
    and its path is yielded to the caller. On exit the previous working
    directory is restored and the temporary directory is deleted, including
    when the body raises an exception.

    Yields:
        str: path of the temporary directory.
    """
    old_cwd = os.getcwd()
    temp_dir = tempfile.mkdtemp()
    try:
        os.chdir(temp_dir)
        yield temp_dir
    finally:
        # Always step out of the directory before removing it, otherwise a
        # failing body would leave the process inside a leaked directory and
        # affect every test that runs afterwards.
        os.chdir(old_cwd)
        shutil.rmtree(temp_dir)

def capture_losses(self, file_name: str):
    """
    Collect the loss values reported in a training log file.

    Only lines starting with "INFO" are inspected. For each of them that
    contains an "iter: ...; lr: ...; loss: ...;" record, the loss field is
    converted to a float.

    Args:
        file_name: path of the log file to read.

    Returns:
        The list of losses, in the order in which they appear in the file.
    """
    pattern = re.compile(r"iter: (.*?); lr: (?:.*?); loss: (.*?);")
    with open(file_name, "r") as log_file:
        # The generator is fully consumed before the file is closed
        hits = (pattern.search(row) for row in log_file if row.startswith("INFO"))
        return [float(hit.group(2)) for hit in hits if hit is not None]
return config

def run_pretraining(
self,
with_fsdp: bool,
with_activation_checkpointing: bool,
with_mixed_precision: bool,
):
with self._in_temporary_directory() as dir_name:
args, config = self._create_pretraining_config(
with in_temporary_directory():
config = self._create_pretraining_config(
with_fsdp=with_fsdp,
with_activation_checkpointing=with_activation_checkpointing,
with_mixed_precision=with_mixed_precision,
)
launch_distributed(
cfg=config,
node_id=args.node_id,
engine_name=args.engine_name,
hook_generator=default_hook_generator,
)
return self.capture_losses(os.path.join(dir_name, "log.txt"))
result = run_integration_test(config)
return result.get_losses()

@unittest.skipIf(torch.cuda.device_count() < 2, "Not enough GPUs to run the test")
@gpu_test(gpu_count=2)
def test_fsdp_integration(self):
ddp_losses = self.run_pretraining(
with_fsdp=False,
Expand All @@ -120,7 +90,7 @@ def test_fsdp_integration(self):
)
self.assertEqual(ddp_losses, fsdp_losses)

@unittest.skipIf(torch.cuda.device_count() < 2, "Not enough GPUs to run the test")
@gpu_test(gpu_count=2)
def test_fsdp_integration_mixed_precision(self):
ddp_losses = self.run_pretraining(
with_fsdp=False,
Expand Down
193 changes: 193 additions & 0 deletions tests/test_state_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import unittest

from hydra.experimental import compose, initialize_config_module
from vissl.config.attr_dict import AttrDict
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.test_utils import (
gpu_test,
in_temporary_directory,
run_integration_test,
)


class TestStateCheckpointing(unittest.TestCase):
    """
    Integration tests for state checkpointing:

    - check that loading a checkpoint during training works
    - check that loading a checkpoint for benchmarking works
    """

    @staticmethod
    def _create_pretraining_config(with_fsdp: bool):
        """
        Build the configuration of a short SwAV pre-training on synthetic data.

        Args:
            with_fsdp: if True, select the FSDP trunk, head and task names,
                otherwise select the "regnet_v2" trunk and the "swav_head" head.

        Returns:
            The configuration produced by `convert_to_attrdict`.
        """
        with initialize_config_module(config_module="vissl.config"):
            cfg = compose(
                "defaults",
                overrides=[
                    "config=test/integration_test/quick_swav",
                    "+config/pretrain/swav/models=regnet16Gf",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.DATA_LIMIT=40",
                    "config.SEED_VALUE=0",
                    "config.MODEL.AMP_PARAMS.USE_AMP=False",
                    "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",
                    "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
                    "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
                    "config.LOSS.swav_loss.epsilon=0.03",
                    "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
                    "config.MODEL.FSDP_CONFIG.mixed_precision=False",
                    "config.MODEL.FSDP_CONFIG.fp32_reduce_scatter=False",
                    "config.MODEL.FSDP_CONFIG.compute_dtype=float32",
                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
                    "config.LOG_FREQUENCY=1",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.OPTIMIZER.use_larc=False",
                    "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                ],
            )
            # Only the configuration is needed, the parsed arguments are dropped
            _, config = convert_to_attrdict(cfg)
            if with_fsdp:
                config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
                config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head_fsdp"
                config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
            else:
                config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
                config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head"
            return config

    def run_preemption_test(self, config: AttrDict, compare_losses: bool = True):
        """
        Run the same configuration twice in the current directory and check
        that the second run reproduces the tail of the first run.

        Between the two runs, `clean_final_checkpoint` and `clean_logs` are
        called on the first result (presumably so that the second run resumes
        from an intermediate checkpoint and writes fresh logs - the helpers
        are defined in `vissl.utils.test_utils`).

        Args:
            config: configuration used for both runs.
            compare_losses: if True, the losses of the second run must also
                match the last losses of the first run, in addition to the
                iteration numbers.
        """
        initial_result = run_integration_test(config)
        initial_iters, initial_losses = initial_result.get_losses_with_iterations()

        initial_result.clean_final_checkpoint()
        initial_result.clean_logs()

        restart_result = run_integration_test(config)
        restart_iters, restart_losses = restart_result.get_losses_with_iterations()

        print("INITIAL:", initial_iters, initial_losses)
        print("RESTART:", restart_iters, restart_losses)

        # With an empty restart, the slice [-0:] below would silently become the
        # whole list: fail early with an explicit message instead.
        self.assertGreater(
            len(restart_iters), 0, "the restarted run did not log any iteration"
        )
        self.assertEqual(initial_iters[-len(restart_iters) :], restart_iters)
        if compare_losses:
            self.assertEqual(initial_losses[-len(restart_losses) :], restart_losses)

    @gpu_test(gpu_count=2)
    def test_restart_after_preemption_at_epoch(self):
        """Restart a 2 epochs pre-training (without FSDP) and compare iterations and losses."""
        with in_temporary_directory():
            config = self._create_pretraining_config(with_fsdp=False)
            config.OPTIMIZER.num_epochs = 2
            self.run_preemption_test(config)

    @gpu_test(gpu_count=2)
    def test_restart_after_preemption_at_epoch_fsdp(self):
        """Restart a 2 epochs pre-training (with FSDP) and compare iterations and losses."""
        with in_temporary_directory():
            config = self._create_pretraining_config(with_fsdp=True)
            config.OPTIMIZER.num_epochs = 2
            self.run_preemption_test(config)

    @gpu_test(gpu_count=2)
    def test_restart_after_preemption_at_iteration(self):
        """Restart a pre-training (without FSDP) checkpointed every 3 iterations."""
        with in_temporary_directory():
            config = self._create_pretraining_config(with_fsdp=False)
            config.CHECKPOINT.CHECKPOINT_ITER_FREQUENCY = 3
            # TODO - understand why the losses do not match exactly for
            #  iteration preemption
            self.run_preemption_test(config, compare_losses=False)

    @gpu_test(gpu_count=2)
    def test_restart_after_preemption_at_iteration_fsdp(self):
        """Restart a pre-training (with FSDP) checkpointed every 3 iterations."""
        with in_temporary_directory():
            config = self._create_pretraining_config(with_fsdp=True)
            config.CHECKPOINT.CHECKPOINT_ITER_FREQUENCY = 3
            # TODO - understand why the losses do not match exactly for
            #  iteration preemption
            self.run_preemption_test(config, compare_losses=False)

    @staticmethod
    def _create_benchmark_config(checkpoint_path: str, with_fsdp: bool):
        """
        Build the configuration of a short linear evaluation on synthetic data,
        whose weights are initialized from `checkpoint_path`.

        Args:
            checkpoint_path: value given to MODEL.WEIGHTS_INIT.PARAMS_FILE.
            with_fsdp: if True, select the FSDP trunk, head and task names,
                otherwise select the "regnet_v2" trunk and the "eval_mlp" head.

        Returns:
            The configuration produced by `convert_to_attrdict`.
        """
        with initialize_config_module(config_module="vissl.config"):
            cfg = compose(
                "defaults",
                overrides=[
                    "config=debugging/benchmark/linear_image_classification/eval_resnet_8gpu_transfer_imagenette_160",
                    "+config/debugging/benchmark/linear_image_classification/models=regnet16Gf_eval_mlp",
                    f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={checkpoint_path}",
                    "config.SEED_VALUE=2",
                    "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
                    "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
                    "config.OPTIMIZER.param_schedulers.lr.lengths=[0.1, 0.9]",
                    "config.OPTIMIZER.param_schedulers.lr.name=cosine",
                    "config.LOSS.swav_loss.epsilon=0.03",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TEST.DATA_SOURCES=[synthetic]",
                    "config.DATA.TEST.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.DATA_LIMIT=40",
                    "config.DATA.TEST.DATA_LIMIT=16",
                    "config.DISTRIBUTED.NCCL_DEBUG=False",
                    "config.MODEL.AMP_PARAMS.USE_AMP=false",
                    "config.MODEL.FSDP_CONFIG.mixed_precision=false",
                    "config.OPTIMIZER.use_larc=false",
                    "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",  # This is critical
                    "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                    "config.DATA.TEST.USE_DEBUGGING_SAMPLER=True",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.DATA.TEST.BATCHSIZE_PER_REPLICA=4",
                    "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
                    "config.MODEL.FSDP_CONFIG.fp32_reduce_scatter=false",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                    # The number of epochs is set once only: the benchmark tests
                    # expect 4 accuracy records
                    "config.OPTIMIZER.num_epochs=2",
                    "config.DISTRIBUTED.NUM_NODES=1",
                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
                ],
            )
            # Only the configuration is needed, the parsed arguments are dropped
            _, config = convert_to_attrdict(cfg)
            if with_fsdp:
                config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
                config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp_fsdp"
                config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
            else:
                config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
                config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp"
            return config

    def run_benchmarking(self, checkpoint_path: str, with_fsdp: bool):
        """
        Run a linear evaluation from `checkpoint_path` in its own temporary
        directory.

        Args:
            checkpoint_path: path of the checkpoint to load the weights from.
            with_fsdp: whether to run the FSDP variant of the model.

        Returns:
            A tuple (losses, accuracies) read from the result of the run.
        """
        with in_temporary_directory():
            config = self._create_benchmark_config(checkpoint_path, with_fsdp=with_fsdp)
            results = run_integration_test(config)
            return results.get_losses(), results.get_accuracies()

    @gpu_test(gpu_count=2)
    def test_benchmarking_from_a_consolidated_checkpoint(self):
        """DDP and FSDP evaluations of a DDP checkpoint must give the same traces."""
        with in_temporary_directory() as checkpoint_folder:
            # Run a pre-training in DDP mode and save a consolidated checkpoint
            config = self._create_pretraining_config(with_fsdp=False)
            run_integration_test(config)
            checkpoint_path = os.path.join(checkpoint_folder, "checkpoint.torch")

            # Now, run both DDP and FSDP linear evaluation and compare the traces
            ddp_losses, ddp_accuracies = self.run_benchmarking(
                checkpoint_path, with_fsdp=False
            )
            fsdp_losses, fsdp_accuracies = self.run_benchmarking(
                checkpoint_path, with_fsdp=True
            )
            self.assertEqual(ddp_losses, fsdp_losses)
            self.assertEqual(ddp_accuracies, fsdp_accuracies)

    @gpu_test(gpu_count=2)
    def test_benchmarking_from_sharded_checkpoint(self):
        """An FSDP evaluation must be able to start from an FSDP checkpoint."""
        with in_temporary_directory() as checkpoint_folder:
            # Run a pre-training in FSDP mode and save a sharded checkpoint
            config = self._create_pretraining_config(with_fsdp=True)
            run_integration_test(config)
            checkpoint_path = os.path.join(checkpoint_folder, "checkpoint.torch")

            # Verify that FSDP can load the checkpoint and run a benchmark on it
            fsdp_losses, fsdp_accuracies = self.run_benchmarking(
                checkpoint_path, with_fsdp=True
            )
            # At least one loss must have been logged (a ">= 0" check on a
            # length could never fail)
            self.assertGreater(len(fsdp_losses), 0)
            self.assertEqual(4, len(fsdp_accuracies))
18 changes: 17 additions & 1 deletion tests/test_utils_hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_inference_of_fsdp_settings_for_swav_pretraining(self):
def test_inference_of_fsdp_settings_for_linear_evaluation(self):
overrides = [
"config=debugging/benchmark/linear_image_classification/eval_resnet_8gpu_transfer_imagenette_160",
"+config/debugging/benchmark/linear_image_classification/models=regnet16Gf",
"+config/debugging/benchmark/linear_image_classification/models=regnet16Gf_mlp",
]

cfg = self._create_config(overrides)
Expand All @@ -59,3 +59,19 @@ def test_inference_of_fsdp_settings_for_linear_evaluation(self):
self.assertEqual(cfg.MODEL.HEAD.PARAMS[0][0], "mlp_fsdp")
self.assertEqual(cfg.MODEL.TRUNK.NAME, "regnet_fsdp")
self.assertEqual(cfg.TRAINER.TASK_NAME, "self_supervision_fsdp_task")

def test_inference_of_fsdp_settings_for_linear_evaluation_with_bn(self):
    """
    Check the head, trunk and task names of the "regnet16Gf_eval_mlp" linear
    evaluation configuration, first as is, then with
    MODEL.FSDP_CONFIG.AUTO_SETUP_FSDP set to True, in which case the FSDP
    variants of the three names are expected.
    """
    base_overrides = [
        "config=debugging/benchmark/linear_image_classification/eval_resnet_8gpu_transfer_imagenette_160",
        "+config/debugging/benchmark/linear_image_classification/models=regnet16Gf_eval_mlp",
    ]

    def check_names(cfg, head_name, trunk_name, task_name):
        self.assertEqual(cfg.MODEL.HEAD.PARAMS[0][0], head_name)
        self.assertEqual(cfg.MODEL.TRUNK.NAME, trunk_name)
        self.assertEqual(cfg.TRAINER.TASK_NAME, task_name)

    # Without the automatic FSDP setup, the names of the configuration are kept
    check_names(
        self._create_config(base_overrides),
        "eval_mlp",
        "regnet",
        "self_supervision_task",
    )

    # With the automatic FSDP setup, the FSDP counterparts are selected
    fsdp_overrides = base_overrides + ["config.MODEL.FSDP_CONFIG.AUTO_SETUP_FSDP=True"]
    check_names(
        self._create_config(fsdp_overrides),
        "eval_mlp_fsdp",
        "regnet_fsdp",
        "self_supervision_fsdp_task",
    )
2 changes: 2 additions & 0 deletions vissl/data/ssl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def _load_labels(self):
labels = np.array(labels).astype(np.int64)
elif label_source == "torchvision_dataset":
labels = np.array(self.data_objs[idx].get_labels()).astype(np.int64)
elif label_source == "synthetic":
labels = np.array([0 for _ in range(len(self.data_objs[idx]))])
else:
raise ValueError(f"unknown label source: {label_source}")
self.label_objs.append(labels)
Expand Down
24 changes: 11 additions & 13 deletions vissl/hooks/log_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _checkpoint_model(
"""
phase_idx = task.phase_idx
num_epochs = task.num_epochs

# check if we need to checkpoint this phase
is_checkpointing_phase = is_checkpoint_phase(
mode_num, mode_frequency, train_phase_idx, num_epochs, mode
Expand Down Expand Up @@ -376,14 +377,21 @@ def _checkpoint_model(
# save the incremented phase_idx as it will incorrectly assume that model
# trained for that phase already.
if mode == "iteration":
phase_idx = phase_idx - 1
model_state_dict["phase_idx"] = model_state_dict["phase_idx"] - 1
if task.train:
train_phase_idx = train_phase_idx - 1
model_state_dict["train_phase_idx"] = train_phase_idx
restart_phase = phase_idx - 1
restart_iteration = task.iteration

# When loading from a phase checkpoint:
else:
restart_phase = phase_idx
restart_iteration = task.iteration

checkpoint_content = {
"phase_idx": phase_idx,
"iteration": task.iteration,
"phase_idx": restart_phase,
"iteration": restart_iteration,
"loss": task.loss.state_dict(),
"iteration_num": task.local_iteration_num,
"train_phase_idx": train_phase_idx,
Expand All @@ -408,16 +416,6 @@ def _checkpoint_model(
else:
checkpoint_writer.save_consolidated_checkpoint(checkpoint_content)

"""
# TODO - remove this
if is_primary() and mode_num == 0:
import subprocess
import submitit
job_id = submitit.JobEnvironment().job_id
subprocess.check_call(["scancel", job_id, "--signal", "TERM"])
"""

def _print_and_save_meters(self, task, train_phase_idx):
"""
Executed only on master gpu at the end of each epoch. Computes the
Expand Down
Loading

0 comments on commit 260869c

Please sign in to comment.