forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix extract features training code (facebookresearch#157)
Summary: Fix the extract features training code and add tests to ensure that it does not regress Pull Request resolved: fairinternal/ssl_scaling#157 Reviewed By: prigoyal Differential Revision: D29232182 Pulled By: QuentinDuval fbshipit-source-id: f66dfea202168d25e577949729e7d28e83296b73
- Loading branch information
1 parent
7ad8eca
commit bb78ee7
Showing
7 changed files
with
212 additions
and
39 deletions.
There are no files selected for viewing
37 changes: 37 additions & 0 deletions
37
configs/config/feature_extraction/dataset/imagenette160.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# @package _global_ | ||
config: | ||
DATA: | ||
NUM_DATALOADER_WORKERS: 5 | ||
TRAIN: | ||
DATA_SOURCES: [disk_folder] | ||
LABEL_SOURCES: [disk_folder] | ||
DATASET_NAMES: [imagenette_160_folder] | ||
BATCHSIZE_PER_REPLICA: 256 | ||
TRANSFORMS: | ||
- name: RandomResizedCrop | ||
size: 128 | ||
- name: RandomHorizontalFlip | ||
- name: ToTensor | ||
- name: Normalize | ||
mean: [0.485, 0.456, 0.406] | ||
std: [0.229, 0.224, 0.225] | ||
MMAP_MODE: True | ||
COPY_TO_LOCAL_DISK: False | ||
COPY_DESTINATION_DIR: /tmp/imagenette_160/ | ||
TEST: | ||
DATA_SOURCES: [disk_folder] | ||
LABEL_SOURCES: [disk_folder] | ||
DATASET_NAMES: [imagenette_160_folder] | ||
BATCHSIZE_PER_REPLICA: 256 | ||
TRANSFORMS: | ||
- name: Resize | ||
size: 160 | ||
- name: CenterCrop | ||
size: 128 | ||
- name: ToTensor | ||
- name: Normalize | ||
mean: [0.485, 0.456, 0.406] | ||
std: [0.229, 0.224, 0.225] | ||
MMAP_MODE: True | ||
COPY_TO_LOCAL_DISK: False | ||
COPY_DESTINATION_DIR: /tmp/imagenette_160/ |
15 changes: 15 additions & 0 deletions
15
configs/config/feature_extraction/with_head/rn50_swav.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# @package _global_ | ||
config: | ||
MODEL: | ||
FEATURE_EVAL_SETTINGS: | ||
EVAL_MODE_ON: True | ||
FREEZE_TRUNK_AND_HEAD: True | ||
EVAL_TRUNK_AND_HEAD: True | ||
TRUNK: | ||
NAME: resnet | ||
RESNETS: | ||
DEPTH: 50 | ||
HEAD: | ||
PARAMS: [ | ||
["swav_head", {"dims": [2048, 2048, 128], "use_bn": True, "num_clusters": [3000]}], | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import os | ||
import unittest | ||
|
||
from hydra.experimental import compose, initialize_config_module | ||
from vissl.utils.hydra_config import convert_to_attrdict | ||
from vissl.utils.test_utils import ( | ||
gpu_test, | ||
in_temporary_directory, | ||
run_integration_test, | ||
) | ||
|
||
|
||
class TestExtractClusterWorkflow(unittest.TestCase): | ||
@staticmethod | ||
def _create_pretraining_config(num_gpu: int = 2): | ||
with initialize_config_module(config_module="vissl.config"): | ||
cfg = compose( | ||
"defaults", | ||
overrides=[ | ||
"config=test/integration_test/quick_swav", | ||
"config.DATA.TRAIN.DATA_SOURCES=[synthetic]", | ||
"config.DATA.TRAIN.DATA_LIMIT=40", | ||
"config.SEED_VALUE=0", | ||
"config.MODEL.AMP_PARAMS.USE_AMP=False", | ||
"config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True", | ||
"config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch", | ||
"config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch", | ||
"config.LOSS.swav_loss.epsilon=0.03", | ||
"config.MODEL.FSDP_CONFIG.flatten_parameters=True", | ||
"config.MODEL.FSDP_CONFIG.mixed_precision=False", | ||
"config.MODEL.FSDP_CONFIG.fp32_reduce_scatter=False", | ||
"config.MODEL.FSDP_CONFIG.compute_dtype=float32", | ||
f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}", | ||
"config.LOG_FREQUENCY=1", | ||
"config.OPTIMIZER.construct_single_param_group_only=True", | ||
"config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4", | ||
"config.OPTIMIZER.use_larc=False", | ||
], | ||
) | ||
|
||
args, config = convert_to_attrdict(cfg) | ||
return config | ||
|
||
@staticmethod | ||
def _create_extract_features_config(checkpoint_path: str, num_gpu: int = 2): | ||
with initialize_config_module(config_module="vissl.config"): | ||
cfg = compose( | ||
"defaults", | ||
overrides=[ | ||
"config=feature_extraction/extract_resnet_in1k_8gpu", | ||
"+config/feature_extraction/with_head=rn50_swav", | ||
f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={checkpoint_path}", | ||
"config.DATA.TRAIN.DATA_SOURCES=[synthetic]", | ||
"config.DATA.TRAIN.LABEL_SOURCES=[synthetic]", | ||
"config.DATA.TEST.DATA_SOURCES=[synthetic]", | ||
"config.DATA.TEST.LABEL_SOURCES=[synthetic]", | ||
"config.DATA.TRAIN.DATA_LIMIT=40", | ||
"config.DATA.TEST.DATA_LIMIT=40", | ||
"config.SEED_VALUE=0", | ||
"config.MODEL.AMP_PARAMS.USE_AMP=False", | ||
"config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True", | ||
"config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch", | ||
"config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch", | ||
"config.LOSS.swav_loss.epsilon=0.03", | ||
"config.MODEL.FSDP_CONFIG.flatten_parameters=True", | ||
"config.MODEL.FSDP_CONFIG.mixed_precision=False", | ||
"config.MODEL.FSDP_CONFIG.fp32_reduce_scatter=False", | ||
"config.MODEL.FSDP_CONFIG.compute_dtype=float32", | ||
f"config.DISTRIBUTED.NUM_PROC_PER_NODE={num_gpu}", | ||
"config.LOG_FREQUENCY=1", | ||
"config.OPTIMIZER.construct_single_param_group_only=True", | ||
"config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4", | ||
"config.OPTIMIZER.use_larc=False", | ||
], | ||
) | ||
args, config = convert_to_attrdict(cfg) | ||
return config | ||
|
||
@gpu_test(gpu_count=2) | ||
def test_extract_cluster_assignment_ddp(self): | ||
with in_temporary_directory() as pretrain_dir: | ||
|
||
pretrain_config = self._create_pretraining_config() | ||
run_integration_test(pretrain_config) | ||
|
||
with in_temporary_directory() as extract_dir: | ||
extract_config = self._create_extract_features_config( | ||
checkpoint_path=os.path.join(pretrain_dir, "checkpoint.torch") | ||
) | ||
|
||
run_integration_test(extract_config, engine_name="extract_features") | ||
folder_content = os.listdir(extract_dir) | ||
print(folder_content) | ||
for rank in [0, 1]: | ||
for feat_name in ["heads"]: | ||
for file in [ | ||
f"rank{rank}_train_{feat_name}_features.npy", | ||
f"rank{rank}_train_{feat_name}_inds.npy", | ||
f"rank{rank}_train_{feat_name}_targets.npy", | ||
]: | ||
self.assertIn(file, folder_content) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters