forked from castorini/birch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
36 lines (30 loc) · 2.29 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from argparse import ArgumentParser
def get_args():
    """Build the birch command-line interface and parse the current argv.

    Returns:
        tuple: ``(args, other)`` where ``args`` is the recognized
        ``argparse.Namespace`` and ``other`` is the list of argv tokens
        left over by ``parse_known_args()``.
    """
    cli = ArgumentParser(description='birch')

    # Top-level run configuration.
    cli.add_argument('--mode', default='retrieval', help='[training, inference, retrieval]')
    cli.add_argument('--output_path', default='out.tmp', help='Name of log file')
    cli.add_argument('--data_path', default='data')
    cli.add_argument('--anserini_path', default='../Anserini', help='Path to Anserini root')
    cli.add_argument('--collection', default='robust04', help='[mb, robust04, core17, core18]')
    cli.add_argument('--trec_eval_path', default='eval/trec_eval.9.0.4/trec_eval')

    # Interactive single-query evaluation.
    cli.add_argument('--interactive', action='store_true', default=False, help='Batch evaluation if not set')
    cli.add_argument('--query', default='hubble space telescope', help='Query string')
    cli.add_argument('--interactive_name', default='query_sents', help='Name of output sentence results from query')

    # Retrieval settings.
    cli.add_argument('--experiment', default=None, help='Experiment name for logging')
    cli.add_argument('--index_path', default='lucene-index.robust04.pos+docvectors+rawdocs', help='Path to Lucene index')

    # Training hyperparameters and model locations.
    cli.add_argument('--device', default='cpu', help='[cuda, cpu]')
    cli.add_argument('--model_path', default='models/saved.tmp', help='Path to pretrained model')
    cli.add_argument('--predict_path', default='data/predictions/predict.tmp')
    cli.add_argument('--batch_size', default=16, type=int)
    cli.add_argument('--learning_rate', default=1e-5, type=float)
    cli.add_argument('--num_train_epochs', default=3, type=int)
    cli.add_argument('--eval_steps', default=-1, type=int, help='Number of evaluation steps, -1 for evaluation per epoch')
    cli.add_argument('--warmup_proportion', default=0.1, type=float, help='Proportion of training to perform linear learning rate warmup. E.g., 0.1 = 10%% of training.')
    cli.add_argument('--local_model', default=None, help='[None, path to local model file]')
    cli.add_argument('--local_tokenizer', default=None, help='[None, path to local vocab file]')
    cli.add_argument('--load_trained', action='store_true', default=False, help='Load pretrained model if True')

    # parse_known_args tolerates extra flags so callers can layer their own.
    args, other = cli.parse_known_args()
    return args, other