Skip to content

Commit

Permalink
Remove hardcoded runs/ output directory (#1654)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasper-xian authored Sep 26, 2023
1 parent 88f1f5b commit 142c774
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyserini/2cr/beir.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def run_conditions(args):

print(f' - dataset: {dataset}')

runfile = os.path.join(args.directory, f'runs/run.beir.{name}.{dataset}.txt')
runfile = os.path.join(args.directory, f'run.beir.{name}.{dataset}.txt')
cmd = Template(cmd_template).substitute(dataset=dataset, output=runfile)

if args.display_commands:
Expand Down
26 changes: 13 additions & 13 deletions pyserini/2cr/odqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
NQ_TOPICS = 'nq-test'
PRINT_TQA_TOPICS = 'TriviaQA'
PRINT_NQ_TOPICS = 'Natural Question'
TQA_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt'
NQ_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt'
TQA_DKRR_RUN = f'run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt'
NQ_DKRR_RUN = f'run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt'
# HITS_1K = set(['GarT5-RRF', 'DPR-DKRR'])

GARRRF_LS = ['answers','titles','sentences']
Expand Down Expand Up @@ -211,25 +211,25 @@ def generate_report(args):
cmd_template_nq = condition_nq['command']
if 'RRF' in name:
if name == 'GarT5-RRF':
runfile_tqa = [os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_tqa))]
runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_nq))]
runfile_tqa = [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_tqa))]
runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_nq))]
tqa_fused_run.update({name: runfile_tqa[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')})
nq_fused_run.update({name: runfile_nq[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')})
jsonfile_tqa = tqa_fused_run[name].replace('.txt', '.json').replace('.hits-1000', '')
jsonfile_nq = nq_fused_run[name].replace('.txt', '.json').replace('.hits-1000', '')
elif name == 'GarT5RRF-DKRR-RRF':
jsonfile_tqa = os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.json')
jsonfile_nq = os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.json')
jsonfile_tqa = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json')
jsonfile_nq = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json')
tqa_fused_run.update({name: jsonfile_tqa.replace('.json','.txt')})
nq_fused_run.update({name: jsonfile_nq.replace('.json','.txt')})
else:
raise NameError('Wrong model name in yaml config')
else:
if 'dpr-topics' in name:
runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.dpr-nq-test.hits-100.txt')]
runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.dpr-nq-test.hits-100.txt')]
else:
runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.{NQ_TOPICS}.hits-100.txt')]
runfile_tqa = [os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.hits-100.txt')]
runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.hits-100.txt')]
runfile_tqa = [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.hits-100.txt')]
jsonfile_tqa = runfile_tqa[0].replace('.answers', '').replace('.txt', '.json')
jsonfile_nq = runfile_nq[0].replace('.answers', '').replace('.txt', '.json')

Expand Down Expand Up @@ -342,9 +342,9 @@ def run_conditions(args):

# running retrieval
if name == "GarT5-RRF":
runfile = [os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.{i}.hits-{hits}.txt') for i in GARRRF_LS]
runfile = [os.path.join(args.directory, f'run.odqa.{name}.{topics}.{i}.hits-{hits}.txt') for i in GARRRF_LS]
else:
runfile = [os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.hits-{hits}.txt')]
runfile = [os.path.join(args.directory, f'run.odqa.{name}.{topics}.hits-{hits}.txt')]

if name != "GarT5RRF-DKRR-RRF":
cmd = [Template(cmd_template[i]).substitute(output=runfile[i]) for i in range(len(runfile))]
Expand All @@ -365,9 +365,9 @@ def run_conditions(args):
output = ''
if name == 'GarT5-RRF':
runs = runfile
output = os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.hits-{hits}.fusion.txt')
output = os.path.join(args.directory, f'run.odqa.{name}.{topics}.hits-{hits}.fusion.txt')
elif name == 'GarT5RRF-DKRR-RRF':
runs = [os.path.join(args.directory, f'runs/run.odqa.DPR-DKRR.{topics}.hits-1000.txt'), os.path.join(args.directory, f'runs/run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')]
runs = [os.path.join(args.directory, f'run.odqa.DPR-DKRR.{topics}.hits-1000.txt'), os.path.join(args.directory, f'run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')]
output = runfile[0].replace('.txt','.fusion.txt')
else:
raise NameError('Unexpected model name')
Expand Down

0 comments on commit 142c774

Please sign in to comment.