forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
111 lines (87 loc) · 3.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import re
import sys
from typing import Any, List
import pkg_resources
from hydra.experimental import compose, initialize_config_module
from omegaconf import OmegaConf
# Module logger: deliberately bound to the shared "vissl" logger name rather
# than this module's __name__.
logger = logging.getLogger(name="vissl")
# List all the config files, used to generate the unit tests on the fly
def list_config_files(dir_path, exclude_folders):
resource_name = "configs"
assert pkg_resources.resource_isdir(resource_name, dir_path)
all_items = pkg_resources.resource_listdir(resource_name, dir_path)
config_files = []
def valid_file(filename):
if not filename.endswith("yaml"):
return False
if exclude_folders and any(x in filename for x in exclude_folders):
return False
return True
for item in all_items:
subpath = f"{dir_path}/{item}"
if pkg_resources.resource_isdir(resource_name, subpath):
# Recursively test all the tree
config_files.extend(list_config_files(subpath, exclude_folders))
if valid_file(subpath):
# If valid leaf, return the config file
config_files.append(subpath)
return config_files
def create_valid_input(input_list):
    """
    Turn config file paths into ``config=<relative path>`` override strings.

    ``list_config_files`` produces paths such as
    ``config/pretrain/swav/swav_8node_resnet.yaml``. Each one is rewritten as
    ``config=pretrain/swav/swav_8node_resnet.yaml``, presumably for use as a
    Hydra override.

    Only the *leading* ``config/`` prefix is rewritten. Replacing every
    occurrence would corrupt paths that contain a nested directory whose name
    ends in "config": ``config/x/my_config/a.yaml`` must become
    ``config=x/my_config/a.yaml``, not ``config=x/my_config=a.yaml``.

    Args:
        input_list: iterable of path strings.

    Returns:
        list of str in input order. Items without a leading ``config/`` are
        returned unchanged.
    """
    prefix = "config/"
    return [
        "config=" + item[len(prefix):] if item.startswith(prefix) else item
        for item in input_list
    ]
# we skip object detection configs since they are for detectron2 codebase
BENCHMARK_CONFIGS = create_valid_input(
list_config_files("config/benchmark", exclude_folders=["object_detection"])
)
PRETRAIN_CONFIGS = create_valid_input(
list_config_files("config/pretrain", exclude_folders=None)
)
INTEGRATION_TEST_CONFIGS = create_valid_input(
list_config_files("config/test/integration_test", exclude_folders=None)
)
ROOT_CONFIGS = create_valid_input(
list_config_files(
"config", exclude_folders=["models", "optimization", "object_detection"]
)
)
ROOT_OSS_CONFIGS = create_valid_input(
list_config_files(
"config", exclude_folders=["models", "optimization", "object_detection", "fb"]
)
)
# configs that require loss optimization and hence trainable
ROOT_LOSS_CONFIGS = create_valid_input(
list_config_files(
"config",
exclude_folders=[
"models",
"optimization",
"object_detection",
"nearest_neighbor",
"feature_extraction",
"fb",
],
)
)
UNIT_TEST_CONFIGS = create_valid_input(
list_config_files("config/test/cpu_test", exclude_folders=None)
)
initialize_config_module(config_module="vissl.config")
class SSLHydraConfig(object):
    """
    Holder for a config composed from Hydra's "defaults" config.

    Attributes:
        overrides: private list copy of the override strings given at
            construction (empty when none were given).
        default_cfg: the composed config. ``override`` merges further
            command-line style values into it.
    """

    def __init__(self, overrides: List[Any] = None):
        # Copy the caller's overrides so later changes to their list do not leak in.
        wants_overrides = overrides is not None and len(overrides) > 0
        self.overrides = list(overrides) if wants_overrides else []
        self.default_cfg = compose(config_name="defaults", overrides=self.overrides)

    @classmethod
    def from_configs(cls, config_files: List[Any] = None):
        """Alternate constructor; equivalent to ``cls(config_files)``."""
        return cls(config_files)

    def override(self, config_files: List[Any]):
        """
        Merge ``config_files`` into ``self.default_cfg``.

        The strings are parsed by ``OmegaConf.from_cli`` (``key=value`` form).

        Side effect: ``sys.argv`` is replaced by ``config_files`` itself (the
        same list object, not a copy) for the whole process.
        """
        sys.argv = config_files
        parsed = OmegaConf.from_cli(config_files)
        merged = OmegaConf.merge(self.default_cfg, parsed)
        self.default_cfg = merged