From c5068d753a61b456563a6850c61739feecf9934a Mon Sep 17 00:00:00 2001 From: QuentinDuval Date: Tue, 15 Mar 2022 05:23:25 -0700 Subject: [PATCH] Upgrade of fairscale (#231) Summary: Upgrading to the latest version of fairscale: - [x] Upgrade of PyTorch version to 1.8.1, minimal versions needed for the latest fairscale - [x] Fixed issued with FSDP paths having changed by making our code insensitive to it - [x] Added unit test to check that we can load checkpoints of previous fairscale versions X-link: https://github.com/fairinternal/ssl_scaling/pull/231 Reviewed By: iseessel Differential Revision: D34730366 Pulled By: QuentinDuval fbshipit-source-id: 44759cc388fe7fc1eeca4f570a31ee3ad78e5fe5 --- .circleci/config.yml | 5 +- INSTALL.md | 18 +++--- .../pretrain/swav/swav_8node_resnet.yaml | 5 -- .../integration_test/quick_swav_2crops.yaml | 2 - tests/test_layer_memory_tracking.py | 28 ++++++++-- tests/test_regnet_fsdp_10b.py | 55 +++++++++++++++++++ vissl/hooks/state_update_hooks.py | 19 +++++-- vissl/utils/fsdp_utils.py | 4 +- 8 files changed, 105 insertions(+), 31 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 0b83b4f1b..454f44c7e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -60,7 +60,7 @@ install_fairscale: &install_fairscale working_directory: ~/ command: | pip uninstall -y fairscale - pip install fairscale@https://github.com/facebookresearch/fairscale/tarball/df7db85cef7f9c30a5b821007754b96eb1f977b6 + pip install fairscale==0.4.6 install_classy_vision: &install_classy_vision @@ -112,7 +112,8 @@ install_vissl_dep: &install_vissl_dep name: Install Dependencies working_directory: ~/vissl command: | - pip install --progress-bar off torch==1.7.1 torchvision==0.8.2 opencv-python==3.4.2.17 + pip install --progress-bar off torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html + pip install --progress-bar off opencv-python==3.4.2.17 pip install --progress-bar off -r requirements.txt # Update this since classy_vision seems to need it. pip install --progress-bar off --upgrade iopath diff --git a/INSTALL.md b/INSTALL.md index a6fb666be..e24cc7299 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -33,14 +33,14 @@ The following instructions assume that you have desired CUDA version installed a If you don't have anaconda, [run this bash scrip to install conda](https://github.com/facebookresearch/vissl/blob/main/docker/common/install_conda.sh). ```bash -conda create -n vissl_env python=3.7 +conda create -n vissl_env python=3.8 source activate vissl_env ``` #### Step 2: Install PyTorch (conda) ```bash -conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch +conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.2 -c pytorch ``` #### Step 3: Install APEX (conda) @@ -67,7 +67,7 @@ pip uninstall -y classy_vision pip install classy-vision@https://github.com/facebookresearch/ClassyVision/tarball/4785d5ee19d3bcedd5b28c1eb51ea1f59188b54d # update fairscale install to commit stable for vissl. pip uninstall -y fairscale -pip install fairscale@https://github.com/facebookresearch/fairscale/tarball/df7db85cef7f9c30a5b821007754b96eb1f977b6 +pip install fairscale==0.4.6 # install vissl dev mode (e stands for editable) pip install -e ".[dev]" # verify installation @@ -85,13 +85,13 @@ python3 -m venv ~/venv #### Step 2: Install PyTorch (pip) ```bash -pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html +pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html ``` #### Step 3: Install APEX (pip) ```bash -pip install -f https://dl.fbaipublicfiles.com/vissl/packaging/apexwheels/py37_cu101_pyt171/download.html apex +pip install -f https://dl.fbaipublicfiles.com/vissl/packaging/apexwheels/py38_cu102_pyt181/download.html apex ``` #### Step 4: Install VISSL @@ -106,7 +106,7 @@ This assumes you have CUDA 10.2. ```bash conda create -n vissl python=3.8 conda activate vissl -conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.2 +conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.2 -c pytorch conda install -c vissl -c iopath -c conda-forge -c pytorch -c defaults apex vissl ``` @@ -127,14 +127,14 @@ python3 -m venv ~/venv #### Step 2: Install PyTorch, OpenCV and APEX (pip) -- We use PyTorch=1.5.1 with CUDA 10.1 in the following instruction (user can chose their desired version). +- We use PyTorch=1.8.1 with CUDA 10.2 in the following instruction (you can chose your desired version). - There are several ways to install opencv, one possibility is as follows. - For APEX, we provide pre-built binary built with optimized C++/CUDA extensions provided by APEX. ```bash -pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html +pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html pip install opencv-python -pip install -f https://dl.fbaipublicfiles.com/vissl/packaging/apexwheels/py38_cu101_pyt151/download.html apex +pip install -f https://dl.fbaipublicfiles.com/vissl/packaging/apexwheels/py38_cu102_pyt181/download.html apex ``` Note that, for the APEX install, you need to get the versions of CUDA, PyTorch, and Python correct in the URL. We provide APEX versions with all possible combinations of Python, PyTorch, CUDA. Select the right APEX Wheels if you desire a different combination. diff --git a/configs/config/pretrain/swav/swav_8node_resnet.yaml b/configs/config/pretrain/swav/swav_8node_resnet.yaml index eda05410c..d9b1d3d4e 100644 --- a/configs/config/pretrain/swav/swav_8node_resnet.yaml +++ b/configs/config/pretrain/swav/swav_8node_resnet.yaml @@ -55,11 +55,6 @@ config: ] TEMP_FROZEN_PARAMS_ITER_MAP: [ ['module.heads.0.prototypes0.weight', 313], - # TODO (Min): FSDP need to return the original param name from named_parameters(). - # Configuration for flatten_parameters = True - ['_fsdp_wrapped_module.heads.0._fsdp_wrapped_module._fpw_module.prototypes0._fsdp_wrapped_module.weight', 313], - # Configuration for flatten_parameters = False - ['_fsdp_wrapped_module.heads.0._fsdp_wrapped_module.prototypes0._fsdp_wrapped_module.weight', 313] ] SYNC_BN_CONFIG: CONVERT_BN_TO_SYNC_BN: True diff --git a/configs/config/test/integration_test/quick_swav_2crops.yaml b/configs/config/test/integration_test/quick_swav_2crops.yaml index be9844e7f..8cf6ed70f 100644 --- a/configs/config/test/integration_test/quick_swav_2crops.yaml +++ b/configs/config/test/integration_test/quick_swav_2crops.yaml @@ -57,8 +57,6 @@ config: ] TEMP_FROZEN_PARAMS_ITER_MAP: [ ['module.heads.0.prototypes0.weight', 313], - # TODO (Min): FSDP need to return the original param name from named_parameters(). - ['_fsdp_wrapped_module.heads.0._fsdp_wrapped_module._fpw_module.prototypes0._fsdp_wrapped_module.weight', 313] ] SYNC_BN_CONFIG: CONVERT_BN_TO_SYNC_BN: True diff --git a/tests/test_layer_memory_tracking.py b/tests/test_layer_memory_tracking.py index ca9522354..8d4f82659 100644 --- a/tests/test_layer_memory_tracking.py +++ b/tests/test_layer_memory_tracking.py @@ -139,12 +139,28 @@ def _layer_memory_tracking_worker(gpu_id: int, sync_file: str, world_size: int): if t.all_gathered > 0 ] assert all_gathered_traces == [ - ("_fsdp_wrapped_module.0", 440, 440), - ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 880), - ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440, 880), - ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440, 0), - ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 0), - ] + ("_fsdp_wrapped_module._fpw_module.0", 440, 440), + ( + "_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", + 440, + 880, + ), + ( + "_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", + 440, + 880, + ), + ( + "_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", + 440, + 0, + ), + ( + "_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", + 440, + 0, + ), + ], f"Expected {all_gathered_traces}" @gpu_test(gpu_count=2) def test_memory_tracking_fsdp(self): diff --git a/tests/test_regnet_fsdp_10b.py b/tests/test_regnet_fsdp_10b.py index 66de6996a..ab720f636 100644 --- a/tests/test_regnet_fsdp_10b.py +++ b/tests/test_regnet_fsdp_10b.py @@ -14,6 +14,13 @@ class TestRegnet10B(unittest.TestCase): + """ + Integrations tests that should be run on 8 GPUs nodes. + + Tests that the RegNet10B trained for SEER still works: + https://arxiv.org/abs/2202.08360 + """ + @staticmethod def _create_10B_pretrain_config(num_gpus: int, num_steps: int, batch_size: int): data_limit = num_steps * batch_size * num_gpus @@ -50,3 +57,51 @@ def test_regnet_10b_swav_pretraining(self): losses = results.get_losses() print(losses) self.assertEqual(len(losses), 2) + + @staticmethod + def _create_10B_evaluation_config( + num_gpus: int, num_steps: int, batch_size: int, path_to_sliced_checkpoint: str + ): + data_limit = num_steps * batch_size * num_gpus + cfg = compose_hydra_configuration( + [ + "config=benchmark/linear_image_classification/clevr_count/eval_resnet_8gpu_transfer_clevr_count_linear", + "+config/benchmark/linear_image_classification/clevr_count/models=regnet10B", + f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={path_to_sliced_checkpoint}", + "config.MODEL.AMP_PARAMS.USE_AMP=True", + "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch", + "config.OPTIMIZER.num_epochs=1", + "config.LOG_FREQUENCY=1", + # Testing on fake images + "config.DATA.TRAIN.DATA_SOURCES=[synthetic]", + "config.DATA.TRAIN.LABEL_SOURCES=[synthetic]", + "config.DATA.TRAIN.RANDOM_SYNTHETIC_IMAGES=True", + "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True", + "config.DATA.TEST.DATA_SOURCES=[synthetic]", + "config.DATA.TEST.LABEL_SOURCES=[synthetic]", + "config.DATA.TEST.RANDOM_SYNTHETIC_IMAGES=True", + "config.DATA.TEST.USE_DEBUGGING_SAMPLER=True", + # Disable overlap communication and computation for test + "config.MODEL.FSDP_CONFIG.FORCE_SYNC_CUDA=True", + # Testing on 8 V100 32GB GPU only + f"config.DATA.TRAIN.BATCHSIZE_PER_REPLICA={batch_size}", + f"config.DATA.TRAIN.DATA_LIMIT={data_limit}", + "config.DISTRIBUTED.NUM_NODES=1", + f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpus}", + "config.DISTRIBUTED.RUN_ID=auto", + ] + ) + args, config = convert_to_attrdict(cfg) + return config + + @gpu_test(gpu_count=8) + def test_regnet_10b_evaluation(self): + with in_temporary_directory(): + cp_path = "/checkpoint/qduval/vissl/seer/regnet10B_sliced/model_iteration124500_sliced.torch" + config = self._create_10B_evaluation_config( + num_gpus=8, num_steps=2, batch_size=4, path_to_sliced_checkpoint=cp_path + ) + results = run_integration_test(config) + losses = results.get_losses() + print(losses) + self.assertGreater(len(losses), 0) diff --git a/vissl/hooks/state_update_hooks.py b/vissl/hooks/state_update_hooks.py index 347c7ab82..8dd3dee6a 100644 --- a/vissl/hooks/state_update_hooks.py +++ b/vissl/hooks/state_update_hooks.py @@ -247,21 +247,28 @@ def on_backward(self, task: "tasks.ClassyTask") -> None: ) return - world_size = ( - task.config.DISTRIBUTED.NUM_NODES - * task.config.DISTRIBUTED.NUM_PROC_PER_NODE - ) - match_param_prefix = "module." if world_size == 1 else "" num_matched_named_params = 0 for name, p in task.model.named_parameters(): - match_param_name = f"{match_param_prefix}{name}" + match_param_name = self._clean_param_path(name) if ( match_param_name in map_params_to_iters ) and task.iteration < map_params_to_iters[match_param_name]: num_matched_named_params += 1 p.grad = None + # TODO (Min): we need to check the exact target number. assert num_matched_named_params > 0, ( f"Didn't find expected number of layers: " f"{num_matched_named_params} vs. {len(map_params_to_iters)}" ) + + @staticmethod + def _clean_param_path(param_name: str) -> str: + # Remove FSDP path artifacts + paths_to_remove = ["_fsdp_wrapped_module.", "_fpw_module."] + for path_to_remove in paths_to_remove: + param_name = param_name.replace(path_to_remove, "") + # Add the missing "module." prefix if missing (DDP prefix) + if not param_name.startswith("module."): + param_name = f"module.{param_name}" + return param_name diff --git a/vissl/utils/fsdp_utils.py b/vissl/utils/fsdp_utils.py index a1234c95d..c435882b0 100644 --- a/vissl/utils/fsdp_utils.py +++ b/vissl/utils/fsdp_utils.py @@ -106,7 +106,9 @@ class _BigConvAutoWrapPolicy: def __init__(self, threshold: int): self.threshold = threshold - def __call__(self, module: nn.Module, recurse: bool, unwrapped_params: int): + def __call__( + self, module: nn.Module, recurse: bool, unwrapped_params: int, **kwargs + ): is_large = unwrapped_params >= self.threshold force_leaf_modules = default_auto_wrap_policy.FORCE_LEAF_MODULES if recurse: