-
Notifications
You must be signed in to change notification settings - Fork 1
/
params.py
75 lines (66 loc) · 4.15 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
from distutils import util
def parse_args():
parser = argparse.ArgumentParser()
## Basics
parser.add_argument('--config_file', help="Configuration file containing parameters",
type=str)
parser.add_argument('--pool', type=str, required=True) # max, last or last1, mean
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--model_path', type=str, default='') # optional, in case of test, if not standard format
parser.add_argument('--mode', type=str, required=True) # train, test, resume, curriculum
parser.add_argument('--debug', type=int, default=0) # 1 for debug mode
parser.add_argument('--task', type=str, default='IMDB') # IMDB, ...
parser.add_argument('--data_size', type=str, default='10K') # 1K, 2K, 5K, 10K, 25K
parser.add_argument('--customlstm', type=int, default=1, choices=[0,1])
parser.add_argument('--log', type=int, default=0, choices=[0,1])
## Exps
parser.add_argument('--ood', type=int, default = 0) # Out of distribution
parser.add_argument('--vec', type=int, default = 0) # 0 or 1 --> for random vectors
parser.add_argument('--wiki', type=str, default = "none") # left, mid, right or none --> for wikipedia words exp --> wiki_left means original is on left
parser.add_argument('--gradients', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--ratios', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--initial', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--gates', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--NWI', type=int, default = 0) # For Positional variation
parser.add_argument('--explain', type=int, default = 0) # For Explanation in case of attention, max
## Hyper-parameters - amsgrad
parser.add_argument('--num_layers', default=1, type=int)
parser.add_argument('--bidirectional', default=True, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--glove', default=True, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--use_embedding', default=True, type=lambda x:bool(util.strtobool(x))) #For MNIST keep False
parser.add_argument('--use_bert', default=False, type=lambda x:bool(util.strtobool(x))) #For MNIST keep False
parser.add_argument('--epochs', type=int, default = 10)
parser.add_argument('--seed', type=int, default = 1234)
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--amsgrad', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--nesterov', default=False, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--lr', type=float, default = 1e-3)
parser.add_argument('--momentum', type=float, default = 0.9)
parser.add_argument('--clip', type=float, default=0)
parser.add_argument('--patience', type=int, default=20)
parser.add_argument('--embed_dim', type=int, default = 100)
parser.add_argument('--hidden_dim', type=int, default = 256)
parser.add_argument('--cap_vocab', default=True, type=lambda x:bool(util.strtobool(x)))
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--forget_bias', type=float, default=1)
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--drop_connect', type=float, default=0)
parser.add_argument('--penalty_power', type=int, default=4)
parser.add_argument('--teach_small', type=int, default=0)
parser.add_argument('--drop_strategy', type=int, default=0)
parser.add_argument('--freeze_embedding', type=int, default=0)
parser.add_argument('--req_ex', type=int, default=500)
return parser
def add_config(args):
data = yaml.load(open(args.config_file,'r'))
args_dict = args.__dict__
for key, value in data.items():
if('--'+key in sys.argv and args_dict[key] != None): ## Giving higher priority to arguments passed in cli
continue
if isinstance(value, list):
args_dict[key] = []
args_dict[key].extend(value)
else:
args_dict[key] = value
return args