-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
179 lines (143 loc) · 5.66 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Python wrapper for experiment configurations.
Original code:
https://github.com/naver-ai/pcme/blob/main/config/config.py
"""
import os
import pickle
import munch
try:
import ujson as json
except ImportError:
import json
import yaml
from yaml.error import YAMLError
try:
import torch
except ImportError:
pass
def _print(msg, verbose):
    """Print ``msg`` to standard output only when ``verbose`` is truthy.

    Parameters
    ----------
    msg:
        The object handed to the built-in ``print``.
    verbose:
        Flag controlling the output; nothing is printed when it is falsy.
    """
    if not verbose:
        return
    print(msg)
def _loader(config_path: str,
            verbose: bool = False) -> dict:
    """Load a serialized configuration, trying each supported format in turn.

    The file is first read as text and parsed with ``yaml.safe_load`` (JSON
    documents are handled by the same parser). If that fails, the file is
    re-opened in binary mode and tried with ``pickle.load`` and then, when
    ``torch`` could be imported, with ``torch.load``.

    Parameters
    ----------
    config_path: str
        Path to the serialized configuration file.
    verbose: bool, optional, default=False
        If True, print a short message every time a loader fails.

    Returns
    -------
    The object produced by the first loader that succeeds.

    Raises
    ------
    TypeError
        If none of the yaml, pickle and torch loaders can read the file.
        The last loader error is attached as the exception cause.

    Notes
    -----
    ``pickle.load`` and ``torch.load`` can execute arbitrary code while
    deserializing. Only load configuration files from trusted sources.

    Examples
    --------
    >>> _loader('test.json')
    """
    last_error = None

    with open(config_path, 'r') as fin:
        try:
            return yaml.safe_load(fin)
        except (YAMLError, UnicodeDecodeError) as err:
            # Binary payloads (pickle / torch) typically fail while being
            # decoded as text, i.e. with UnicodeDecodeError rather than
            # YAMLError, so both must lead to the binary fallbacks.
            last_error = err
            _print('failed to load from yaml. Try pickle loader', verbose)

    with open(config_path, 'rb') as fin:
        try:
            # SECURITY: unpickling runs arbitrary code from the file.
            return pickle.load(fin)
        except Exception as err:
            # Unpickling foreign bytes may raise UnpicklingError, EOFError,
            # AttributeError, ImportError, IndexError, ... so the fallback
            # has to be broad. The failure is re-raised as TypeError below
            # if no other loader succeeds.
            last_error = err
            _print('failed to load from pickle. Try torch loader', verbose)

        try:
            torch_load = torch.load
        except NameError:
            # torch is an optional import of this module.
            torch_load = None

        if torch_load is not None:
            # The failed pickle attempt consumed part of the stream.
            fin.seek(0)
            try:
                return torch_load(fin)
            except Exception as err:
                last_error = err
        _print('failed to load from torch. Please check your configuration again.', verbose)

    raise TypeError('config_path should be serialized by [yaml, json, pickle, torch pth]') from last_error
def dump_config(config: munch.Munch,
                dump_to: str,
                overwrite: bool = False,
                serializer: str = 'json'):
    """Serialize a configuration object to the file ``dump_to``.

    Parameters
    ----------
    config: munch.Munch
        The configuration to write. It is converted with
        ``munch.unmunchify`` before being serialized.
    dump_to: str
        Destination path of the dumped configuration.
    overwrite: bool, optional, default=False
        If False, an already existing ``dump_to`` raises FileExistsError.
    serializer: str, optional, default='json'
        Output format, one of "json", "yaml", "pickle" or "torch".
        "pickle" and "torch" are written in binary mode, the others as text.

    Raises
    ------
    ValueError
        If ``serializer`` is not one of the supported formats.
    FileExistsError
        If ``dump_to`` exists and ``overwrite`` is False.

    Examples
    --------
    >>> dump_config(config, 'my_simple_config.json')
    """
    supported = {'json', 'yaml', 'pickle', 'torch'}
    if serializer not in supported:
        raise ValueError(f'format should be in ["json", "yaml", "pickle", "torch"], not {serializer}')

    already_there = os.path.exists(dump_to)
    if already_there and not overwrite:
        raise FileExistsError(dump_to)

    # pickle and torch produce bytes; json and yaml produce text.
    file_mode = 'wb' if serializer in ('pickle', 'torch') else 'w'

    plain_config = munch.unmunchify(config)

    with open(dump_to, file_mode) as sink:
        if serializer == 'torch':
            torch.save(plain_config, sink)
        elif serializer == 'pickle':
            pickle.dump(plain_config, sink)
        elif serializer == 'yaml':
            yaml.dump(plain_config, sink)
        elif serializer == 'json':
            json.dump(plain_config, sink, indent=4, sort_keys=True)
def _cast_like(reference, value):
    """Cast ``value`` to the type of ``reference`` (the value being replaced).

    Two cases need special care because calling the type directly is wrong:

    * ``reference is None``: ``type(None)`` accepts no arguments, so the new
      value is returned unchanged.
    * ``reference`` is a bool and ``value`` is a string: ``bool('False')`` is
      True, so the usual "false" spellings are mapped to False explicitly.
      Any other string keeps the plain ``bool(value)`` result.
    """
    if reference is None:
        return value
    if isinstance(reference, bool) and isinstance(value, str):
        if value.strip().lower() in ('false', '0', 'no', 'off'):
            return False
        return bool(value)
    return type(reference)(value)


def parse_config(config_fname: str,
                 delimiter: str = '__',
                 strict_cast: bool = True,
                 verbose: bool = False,
                 **kwargs) -> munch.Munch:
    """Parse the given configuration file with additional options to overwrite.

    Parameters
    ----------
    config_fname: str
        A configuration file defines the structure of the configuration.
        The file should be serialized by any of [yaml, json, pickle, torch].
    delimiter: str, optional, default='__'
        A delimiter for the additional kwargs configuration.
        See kwargs for more information.
    strict_cast: bool, optional, default=True
        If True, the overwritten config values will be cast to the type of
        the value they replace. A value replacing ``None`` is stored as
        given, and the strings "false", "0", "no" and "off" (in any letter
        case) replacing a bool are stored as False.
    verbose: bool, optional, default=False
        If True, print the overwriting kwargs and loader messages.
    kwargs: optional
        If specified, overwrite the current configuration by the given keywords.
        For the multi-depth configuration, each keyword is split by
        `delimiter` ("__" by default) into a path of nested keys.
        Every key on the path except the last one must already exist in the
        configuration, otherwise KeyError is raised. The last key must exist
        as well when `strict_cast` is True; when `strict_cast` is False a
        missing last key is created.

    Returns
    -------
    config: munch.Munch
        The configuration, which provides attribute-style access.
        See `Munch <https://github.com/Infinidat/munch>`_ project for the details.

    Raises
    ------
    KeyError
        If a keyword refers to a key that is not defined in the
        configuration (see `kwargs` for the `strict_cast=False` exception).
    TypeError
        If the configuration file cannot be loaded by any supported loader.

    Examples
    --------
    >>> # simple_config.json => {"opt1": {"opt2": 1}, "opt3": 0}
    >>> config = parse_config('simple_config.json')
    >>> print(config.opt1.opt2, config.opt3, type(config.opt1.opt2), type(config.opt3))
    1 0 <class 'int'> <class 'int'>
    >>> config = parse_config('simple_config.json', opt1__opt2=2, opt3=1)
    >>> print(config.opt1.opt2, config.opt3, type(config.opt1.opt2), type(config.opt3))
    2 1 <class 'int'> <class 'int'>
    >>> config = parse_config('simple_config.json', **{'opt1__opt2': '2', 'opt3': 1.0})
    >>> print(config.opt1.opt2, config.opt3, type(config.opt1.opt2), type(config.opt3))
    2 1 <class 'int'> <class 'int'>
    >>> config = parse_config('simple_config.json', **{'opt1__opt2': '2', 'opt3': 1.0}, strict_cast=False)
    >>> print(config.opt1.opt2, config.opt3, type(config.opt1.opt2), type(config.opt3))
    2 1.0 <class 'str'> <class 'float'>
    """
    config = _loader(config_fname, verbose)

    if kwargs:
        _print(f'overwriting configurations: {kwargs}', verbose)

    for arg_key, arg_val in kwargs.items():
        # 'a__b__c' -> walk into config['a']['b'], then overwrite key 'c'.
        *parents, leaf = arg_key.split(delimiter)
        node = config
        for key in parents:
            node = node[key]

        if strict_cast:
            # node[leaf] raises KeyError for keys missing from the file.
            node[leaf] = _cast_like(node[leaf], arg_val)
        else:
            node[leaf] = arg_val

    return munch.munchify(config)