-
Notifications
You must be signed in to change notification settings - Fork 8
/
config.py
123 lines (105 loc) · 4.07 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import division
from __future__ import print_function
import numpy as np
from easydict import EasyDict as edict
__C = edict()
# Public alias: other modules do `from config import cfg`.
cfg = __C

# global options
__C.gpu_id = 0
__C.num_workers = 4
__C.multi_gpus = True
__C.seed = 666

# training options
__C.train = edict()
__C.train.restore = False
__C.train.lr = 0.0001
__C.train.batch_size = 32
__C.train.max_epochs = 25
__C.train.motion_dim = 1024      # size of motion features (S3D, per filename below)
__C.train.appearance_dim = 1024  # size of appearance features (ResNet, per filename below)
__C.train.word_dim = 300         # word-embedding size (GloVe-compatible)
__C.train.module_dim = 512
__C.train.train_num = 0  # Default 0 for full train set
__C.train.glove = True
# __C.train.k_max_frame_level = 16
# __C.train.k_max_clip_level = 8
# __C.train.spl_resolution = 1
# Downcast the section to a plain dict.
# NOTE(review): merge_cfg compares types strictly; merging a YAML edict
# section against this plain dict may be rejected -- confirm intended.
__C.train = dict(__C.train)

# validation
__C.val = edict()
__C.val.flag = True
__C.val.val_num = 0  # Default 0 for full val set
__C.val = dict(__C.val)

# test
__C.test = edict()
__C.test.test_num = 0  # Default 0 for full test set
__C.test.write_preds = False
__C.test = dict(__C.test)

# dataset options
__C.dataset = edict()
__C.dataset.name = 'sutd-qa' # ['sutd-qa', 'tgif-qa', 'msrvtt-qa', 'msvd-qa']
__C.dataset.question_type = 'none' #['frameqa', 'count', 'transition', 'action', 'none']
__C.dataset.data_dir = ''
# Filename templates; the two '{}' slots are presumably filled with the
# dataset name and question type by the data-loading code -- TODO confirm.
__C.dataset.appearance_feat = '{}_{}_appearance_feat_ResNet.h5'
__C.dataset.motion_feat = '{}_{}_motion_feat_S3D.h5'
__C.dataset.appearance_dict = '{}_{}_appearance_feat_ResNet_dict.h5'
__C.dataset.motion_dict = '{}_{}_motion_feat_S3D_dict.h5'
__C.dataset.vocab_json = '{}_{}_vocab.json'
__C.dataset.vocab_subject_json = '{}_{}_vocab_subject.json'
__C.dataset.vocab_relation_json = '{}_{}_vocab_relation.json'
__C.dataset.vocab_object_json = '{}_{}_vocab_object.json'
__C.dataset.train_question_pt = '{}_{}_train_questions.pt'
__C.dataset.train_question_subject_pt = '{}_{}_train_questions_subject.pt'
__C.dataset.train_question_relation_pt = '{}_{}_train_questions_relation.pt'
__C.dataset.train_question_object_pt = '{}_{}_train_questions_object.pt'
# NOTE(review): validation reuses the *test* question files -- confirm this
# is the intended evaluation protocol.
__C.dataset.val_question_pt = '{}_{}_test_questions.pt'
__C.dataset.val_question_subject_pt = '{}_{}_test_questions_subject.pt'
__C.dataset.val_question_relation_pt = '{}_{}_test_questions_relation.pt'
__C.dataset.val_question_object_pt = '{}_{}_test_questions_object.pt'
__C.dataset.test_question_pt = '{}_{}_test_questions.pt'
__C.dataset.test_question_subject_pt = '{}_{}_test_questions_subject.pt'
__C.dataset.test_question_relation_pt = '{}_{}_test_questions_relation.pt'
__C.dataset.test_question_object_pt = '{}_{}_test_questions_object.pt'
# (a duplicate re-assignment of test_question_pt was removed here)
__C.dataset.save_dir = ''
__C.video_file = 'datasets/SUTD-TrafficQA/annotations/R3_all.jsonl'
__C.video_dir = 'datasets/SUTD-TrafficQA/raw_videos/'
__C.dataset = dict(__C.dataset)

# experiment name
__C.exp_name = 'defaultExp'
# credit https://github.com/tohinz/pytorch-mac-network/blob/master/code/config.py
def merge_cfg(yaml_cfg, cfg):
    """Recursively merge user-supplied overrides into the default config.

    Args:
        yaml_cfg: an ``edict`` of overrides (typically loaded from YAML).
            Anything that is not an ``edict`` is silently ignored.
        cfg: the default config mapping, updated in place.

    Raises:
        KeyError: if ``yaml_cfg`` contains a key absent from ``cfg``.
        ValueError: if an override's type cannot be reconciled with the
            default value's type.
    """
    if type(yaml_cfg) is not edict:
        return
    for k, v in yaml_cfg.items():
        if k not in cfg:
            raise KeyError('{} is not a valid config key'.format(k))
        old_type = type(cfg[k])
        if old_type is not type(v):
            # Try to coerce the incoming value to the default's type.
            if isinstance(cfg[k], np.ndarray):
                v = np.array(v, dtype=cfg[k].dtype)
            elif isinstance(cfg[k], list):
                # List overrides arrive from YAML as comma-separated ints.
                v = [int(_v) for _v in v.split(",")]
            elif cfg[k] is None:
                # A None default accepts any override; the literal string
                # "None" means "keep the default".
                if v == "None":
                    continue
            elif isinstance(v, dict) and isinstance(cfg[k], dict):
                # edict vs. plain dict (this file downcasts sections with
                # dict(...)): both are mappings, so fall through and merge
                # them recursively instead of rejecting the override.
                pass
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(cfg[k]),
                                                               type(v), k))
        # recursively merge dicts
        if type(v) is edict:
            try:
                merge_cfg(yaml_cfg[k], cfg[k])
            except Exception:
                # Report which section failed before propagating.
                print('Error under config key: {}'.format(k))
                raise
        else:
            cfg[k] = v
def cfg_from_file(file_name):
    """Read ``file_name`` as YAML and merge its settings into the global config."""
    import yaml

    with open(file_name, 'r') as stream:
        loaded = yaml.load(stream, Loader=yaml.FullLoader)
    merge_cfg(edict(loaded), __C)