From 3d6b4ab12b5a33d70c2aaee400893e8a38f515e8 Mon Sep 17 00:00:00 2001 From: Mannat Singh Date: Tue, 10 Mar 2020 20:32:21 -0700 Subject: [PATCH] Add support for Sync BN (#423) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/423 Added support for using sync batch normalization using PyTorch's implementation or Apex's. Plugged in the model complexity hook to `classy_train.py`. It helps test the bug I encountered and fixed which needs the profiler + sync batch norm. Reviewed By: vreis Differential Revision: D20307435 fbshipit-source-id: 45e02d4fced1b4e78c0c45264a6e2fc85825dd3f --- classy_train.py | 3 +- classy_vision/tasks/classification_task.py | 43 +++++++++++++++++++++- test/generic/config_utils.py | 7 ++-- test/trainer_distributed_trainer_test.py | 26 +++++++++++++ 4 files changed, 73 insertions(+), 6 deletions(-) diff --git a/classy_train.py b/classy_train.py index 26e2eb53b8..852627be2c 100755 --- a/classy_train.py +++ b/classy_train.py @@ -50,6 +50,7 @@ from classy_vision.hooks import ( CheckpointHook, LossLrMeterLoggingHook, + ModelComplexityHook, ProfilerHook, ProgressBarHook, TensorboardPlotHook, @@ -118,7 +119,7 @@ def main(args, config): def configure_hooks(args, config): - hooks = [LossLrMeterLoggingHook(args.log_freq)] + hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()] # Make a folder to store checkpoints and tensorboard logging outputs suffix = datetime.now().isoformat() diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 24670b1bb9..7f7cff2596 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Union import torch +import torch.nn as nn from classy_vision.dataset import ClassyDataset, build_dataset from classy_vision.generic.distributed_util import ( all_reduce_mean, @@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum): BEFORE_EVAL = enum.auto() +class BatchNormSyncMode(enum.Enum): + DISABLED = enum.auto() # No Synchronized Batch Normalization + PYTORCH = enum.auto() # Use torch.nn.SyncBatchNorm + APEX = enum.auto() # Use apex.parallel.SyncBatchNorm, needs apex to be installed + + class LastBatchInfo(NamedTuple): loss: torch.Tensor output: torch.Tensor @@ -133,6 +140,7 @@ def __init__(self): self.amp_opt_level = None self.perf_log = [] self.last_batch = None + self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED def set_checkpoint(self, checkpoint): """Sets checkpoint on task. @@ -204,14 +212,35 @@ def set_meters(self, meters: List["ClassyMeter"]): self.meters = meters return self - def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode): + def set_distributed_options( + self, + broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED, + batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED, + ): """Set distributed options. Args: broadcast_buffers_mode: Broadcast buffers mode. See :class:`BroadcastBuffersMode` for options. + batch_norm_sync_mode: Batch normalization synchronization mode. See + :class:`BatchNormSyncMode` for options. + + Raises: + RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex + is not installed. """ self.broadcast_buffers_mode = broadcast_buffers_mode + + if batch_norm_sync_mode == BatchNormSyncMode.DISABLED: + logging.info("Synchronized Batch Normalization is disabled") + else: + if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available: + raise RuntimeError("apex is not installed") + logging.info( + f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}" + ) + self.batch_norm_sync_mode = batch_norm_sync_mode + return self def set_hooks(self, hooks: List["ClassyHook"]): @@ -317,7 +346,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": .set_meters(meters) .set_amp_opt_level(amp_opt_level) .set_distributed_options( - BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")] + broadcast_buffers_mode=BroadcastBuffersMode[ + config.get("broadcast_buffers", "disabled").upper() + ], + batch_norm_sync_mode=BatchNormSyncMode[ + config.get("batch_norm_sync_mode", "disabled").upper() + ], ) ) for phase_type in phase_types: @@ -494,6 +528,11 @@ def prepare( multiprocessing_context=dataloader_mp_context, ) + if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH: + self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model) + elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX: + self.base_model = apex.parallel.convert_syncbn_model(self.base_model) + # move the model and loss to the right device if use_gpu: self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss) diff --git a/test/generic/config_utils.py b/test/generic/config_utils.py index 2458f3e853..61e32f5da5 100644 --- a/test/generic/config_utils.py +++ b/test/generic/config_utils.py @@ -176,9 +176,9 @@ def get_test_mlp_task_config(): "num_classes": 2, "crop_size": 20, "class_ratio": 0.5, - "num_samples": 10, + "num_samples": 20, "seed": 0, - "batchsize_per_replica": 3, + "batchsize_per_replica": 4, "use_augmentation": False, "use_shuffle": True, "transforms": [ @@ -201,7 +201,7 @@ def get_test_mlp_task_config(): "num_classes": 2, "crop_size": 20, "class_ratio": 0.5, - "num_samples": 10, + "num_samples": 20, "seed": 0, "batchsize_per_replica": 1, "use_augmentation": False, @@ -228,6 +228,7 @@ def get_test_mlp_task_config(): "input_dim": 1200, "output_dim": 1000, "hidden_dims": [10], + "use_batchnorm": True, # used for testing sync batchnorm }, "meters": {"accuracy": {"topk": [1]}}, "optimizer": { diff --git a/test/trainer_distributed_trainer_test.py b/test/trainer_distributed_trainer_test.py index e3950d74b3..dd5c1124f2 100644 --- a/test/trainer_distributed_trainer_test.py +++ b/test/trainer_distributed_trainer_test.py @@ -22,10 +22,13 @@ def setUp(self): config = get_test_mlp_task_config() invalid_config = copy.deepcopy(config) invalid_config["name"] = "invalid_task" + sync_bn_config = copy.deepcopy(config) + sync_bn_config["sync_batch_norm_mode"] = "pytorch" self.config_files = {} for config_key, config in [ ("config", config), ("invalid_config", invalid_config), + ("sync_bn_config", sync_bn_config), ]: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: json.dump(config, f) @@ -63,3 +66,26 @@ def test_training(self): result = subprocess.run(cmd, shell=True) success = result.returncode == 0 self.assertEqual(success, expected_success) + + @unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run") + def test_sync_batch_norm(self): + """Test that sync batch norm training doesn't hang.""" + + num_processes = 2 + device = "gpu" + + cmd = f"""{sys.executable} -m torch.distributed.launch \ + --nnodes=1 \ + --nproc_per_node={num_processes} \ + --master_addr=localhost \ + --master_port=29500 \ + --use_env \ + {self.path}/../classy_train.py \ + --device={device} \ + --config={self.config_files["sync_bn_config"]} \ + --num_workers=4 \ + --log_freq=100 \ + --distributed_backend=ddp + """ + result = subprocess.run(cmd, shell=True) + self.assertEqual(result.returncode, 0)