From 142c774a303c906ee245913bc7e714b165074b77 Mon Sep 17 00:00:00 2001 From: Jasper Xian <41269031+jasper-xian@users.noreply.github.com> Date: Mon, 25 Sep 2023 23:55:41 -0400 Subject: [PATCH] Remove hardcoded runs/ output directory (#1654) --- pyserini/2cr/beir.py | 2 +- pyserini/2cr/odqa.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyserini/2cr/beir.py b/pyserini/2cr/beir.py index 5e8692807..36f1be3e0 100644 --- a/pyserini/2cr/beir.py +++ b/pyserini/2cr/beir.py @@ -187,7 +187,7 @@ def run_conditions(args): print(f' - dataset: {dataset}') - runfile = os.path.join(args.directory, f'runs/run.beir.{name}.{dataset}.txt') + runfile = os.path.join(args.directory, f'run.beir.{name}.{dataset}.txt') cmd = Template(cmd_template).substitute(dataset=dataset, output=runfile) if args.display_commands: diff --git a/pyserini/2cr/odqa.py b/pyserini/2cr/odqa.py index cb7f6d752..a6ee5d97a 100644 --- a/pyserini/2cr/odqa.py +++ b/pyserini/2cr/odqa.py @@ -49,8 +49,8 @@ NQ_TOPICS = 'nq-test' PRINT_TQA_TOPICS = 'TriviaQA' PRINT_NQ_TOPICS = 'Natural Question' -TQA_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt' -NQ_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt' +TQA_DKRR_RUN = f'run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt' +NQ_DKRR_RUN = f'run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt' # HITS_1K = set(['GarT5-RRF', 'DPR-DKRR']) GARRRF_LS = ['answers','titles','sentences'] @@ -211,25 +211,25 @@ def generate_report(args): cmd_template_nq = condition_nq['command'] if 'RRF' in name: if name == 'GarT5-RRF': - runfile_tqa = [os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_tqa))] - runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_nq))] + runfile_tqa = [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_tqa))] + runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_nq))] tqa_fused_run.update({name: runfile_tqa[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')}) nq_fused_run.update({name: runfile_nq[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')}) jsonfile_tqa = tqa_fused_run[name].replace('.txt', '.json').replace('.hits-1000', '') jsonfile_nq = nq_fused_run[name].replace('.txt', '.json').replace('.hits-1000', '') elif name == 'GarT5RRF-DKRR-RRF': - jsonfile_tqa = os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.json') - jsonfile_nq = os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.json') + jsonfile_tqa = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json') + jsonfile_nq = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json') tqa_fused_run.update({name: jsonfile_tqa.replace('.json','.txt')}) nq_fused_run.update({name: jsonfile_nq.replace('.json','.txt')}) else: raise NameError('Wrong model name in yaml config') else: if 'dpr-topics' in name: - runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.dpr-nq-test.hits-100.txt')] + runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.dpr-nq-test.hits-100.txt')] else: - runfile_nq = [os.path.join(args.directory, f'runs/run.odqa.{name}.{NQ_TOPICS}.hits-100.txt')] - runfile_tqa = [os.path.join(args.directory, f'runs/run.odqa.{name}.{TQA_TOPICS}.hits-100.txt')] + runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.hits-100.txt')] + runfile_tqa = [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.hits-100.txt')] jsonfile_tqa = runfile_tqa[0].replace('.answers', '').replace('.txt', '.json') jsonfile_nq = runfile_nq[0].replace('.answers', '').replace('.txt', '.json') @@ -342,9 +342,9 @@ def run_conditions(args): # running retrieval if name == "GarT5-RRF": - runfile = [os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.{i}.hits-{hits}.txt') for i in GARRRF_LS] + runfile = [os.path.join(args.directory, f'run.odqa.{name}.{topics}.{i}.hits-{hits}.txt') for i in GARRRF_LS] else: - runfile = [os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.hits-{hits}.txt')] + runfile = [os.path.join(args.directory, f'run.odqa.{name}.{topics}.hits-{hits}.txt')] if name != "GarT5RRF-DKRR-RRF": cmd = [Template(cmd_template[i]).substitute(output=runfile[i]) for i in range(len(runfile))] @@ -365,9 +365,9 @@ def run_conditions(args): output = '' if name == 'GarT5-RRF': runs = runfile - output = os.path.join(args.directory, f'runs/run.odqa.{name}.{topics}.hits-{hits}.fusion.txt') + output = os.path.join(args.directory, f'run.odqa.{name}.{topics}.hits-{hits}.fusion.txt') elif name == 'GarT5RRF-DKRR-RRF': - runs = [os.path.join(args.directory, f'runs/run.odqa.DPR-DKRR.{topics}.hits-1000.txt'), os.path.join(args.directory, f'runs/run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')] + runs = [os.path.join(args.directory, f'run.odqa.DPR-DKRR.{topics}.hits-1000.txt'), os.path.join(args.directory, f'run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')] output = runfile[0].replace('.txt','.fusion.txt') else: raise NameError('Unexpected model name')