diff --git a/pyserini/2cr/beir.py b/pyserini/2cr/beir.py
index ebe54a1b0..4a3f5c3af 100644
--- a/pyserini/2cr/beir.py
+++ b/pyserini/2cr/beir.py
@@ -64,17 +64,18 @@
     'scifact'
 ]
 
+
 def format_run_command(raw):
-    return raw.replace('--topics', '\\\n  --topics')\
-        .replace('--index', '\\\n  --index')\
-        .replace('--encoder-class', '\\\n  --encoder-class')\
-        .replace('--output ', '\\\n  --output ')\
+    return raw.replace('--topics', '\\\n  --topics') \
+        .replace('--index', '\\\n  --index') \
+        .replace('--encoder-class', '\\\n  --encoder-class') \
+        .replace('--output ', '\\\n  --output ') \
         .replace('--output-format trec', '\\\n  --output-format trec \\\n ') \
         .replace('--hits ', '\\\n  --hits ')
 
 
 def format_eval_command(raw):
-    return raw.replace('-c ', '\\\n  -c ')\
+    return raw.replace('-c ', '\\\n  -c ') \
         .replace('run.', '\\\n  run.')
 
 
@@ -85,16 +86,19 @@ def read_file(f):
     return text
 
 
-def list_conditions(args):
+
+def list_conditions():
     with open(pkg_resources.resource_filename(__name__, 'beir.yaml')) as f:
         yaml_data = yaml.safe_load(f)
     for condition in yaml_data['conditions']:
         print(condition['name'])
-
-def list_datasets(args):
+
+
+def list_datasets():
     for dataset in beir_keys:
         print(dataset)
 
+
 def generate_report(args):
     table = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))
     commands = defaultdict(lambda: defaultdict(lambda: ''))
@@ -149,8 +153,7 @@ def generate_report(args):
                          eval_cmd2=eval_commands[dataset]["bm25-multifield"].rstrip(),
                          eval_cmd3=eval_commands[dataset]["splade-distil-cocodenser-medium"].rstrip(),
                          eval_cmd4=eval_commands[dataset]["contriever"].rstrip(),
-                         eval_cmd5=eval_commands[dataset]["contriever-msmarco"].rstrip(),
-                         )
+                         eval_cmd5=eval_commands[dataset]["contriever-msmarco"].rstrip())
 
         html_rows.append(s)
         row_cnt += 1
@@ -159,6 +162,7 @@ def generate_report(args):
     with open(args.output, 'w') as out:
         out.write(Template(html_template).substitute(title='BEIR', rows=all_rows))
 
+
 def run_conditions(args):
     start = time.time()
 
@@ -253,16 +257,17 @@ def run_conditions(args):
               f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}{table[dataset]["contriever-msmarco"]["R@100"]:8.4f}')
 
     print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
    print('avg' + ' ' * 22 + f'{final_scores["bm25-flat"]["nDCG@10"]:8.4f}{final_scores["bm25-flat"]["R@100"]:8.4f} ' +
-        f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
-        f'{final_scores["splade-distil-cocodenser-medium"]["nDCG@10"]:8.4f}{final_scores["splade-distil-cocodenser-medium"]["R@100"]:8.4f} ' +
-        f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
-        f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f}')
+          f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
+          f'{final_scores["splade-distil-cocodenser-medium"]["nDCG@10"]:8.4f}{final_scores["splade-distil-cocodenser-medium"]["R@100"]:8.4f} ' +
+          f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
+          f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f}')
 
     end = time.time()
     print('\n')
-    print(f'Total elapsed time: {end - start:.0f}s')
-
+    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Generate regression matrix for BeIR corpora.')
     # To list all conditions/datasets
@@ -282,11 +287,11 @@ def run_conditions(args):
     args = parser.parse_args()
 
     if args.list_conditions:
-        list_conditions(args)
+        list_conditions()
         sys.exit()
 
     if args.list_datasets:
-        list_datasets(args)
+        list_datasets()
         sys.exit()
 
     if args.generate_report:
diff --git a/pyserini/2cr/miracl.py b/pyserini/2cr/miracl.py
index f6a8926e8..2ad44155e 100644
--- a/pyserini/2cr/miracl.py
+++ b/pyserini/2cr/miracl.py
@@ -188,8 +188,7 @@ def generate_table_rows(table, row_template, commands, eval_commands, table_id,
                          eval_cmd15=f'{eval_commands[keys["th"]][metric]}',
                          eval_cmd16=f'{eval_commands[keys["zh"]][metric]}',
                          eval_cmd17=f'{eval_commands[keys["de"]][metric]}',
-                         eval_cmd18=f'{eval_commands[keys["yo"]][metric]}'
-                         )
+                         eval_cmd18=f'{eval_commands[keys["yo"]][metric]}')
 
         s = s.replace("0.000", "--")
         html_rows.append(s)
@@ -389,15 +388,15 @@ def run_conditions(args):
                                    and split == 'dev' and metric == 'nDCG@10'
                                    and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                (name == 'mdpr-tied-pft-msmarco-ft-all.ru'
-                                   # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3932 -> expected 0.3933
+                                # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3932 -> expected 0.3933
                                    and split == 'dev' and metric == 'nDCG@10'
                                    and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                (name == 'bm25-mdpr-tied-pft-msmarco-hybrid.te'
-                                   # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.6000 -> expected 0.5999
+                                # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.6000 -> expected 0.5999
                                    and split == 'train' and metric == 'nDCG@10'
                                    and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                (name == 'mcontriever-tied-pft-msmarco.id'
-                                   # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3748 -> expected 0.3749
+                                # Flaky on Jimmy's Mac Studio (Apple M1 Ultra), nDCG@10: 0.3748 -> expected 0.3749
                                    and split == 'train' and metric == 'nDCG@10'
                                    and math.isclose(score, float(expected[metric]), abs_tol=2e-4)):
                                result_str = okish_str
@@ -415,7 +414,7 @@ def run_conditions(args):
         print_results(table, metric, split)
 
     end = time.time()
-    print(f'Total elapsed time: {end - start:.0f}s')
+    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')
 
 
 if __name__ == '__main__':
diff --git a/pyserini/2cr/mrtydi.py b/pyserini/2cr/mrtydi.py
index ab756079c..9c437770f 100644
--- a/pyserini/2cr/mrtydi.py
+++ b/pyserini/2cr/mrtydi.py
@@ -58,17 +58,17 @@
 
 
 def format_run_command(raw):
-    return raw.replace('--lang', '\\\n  --lang')\
-        .replace('--encoder', '\\\n  --encoder')\
-        .replace('--topics', '\\\n  --topics')\
-        .replace('--index', '\\\n  --index')\
-        .replace('--output ', '\\\n  --output ')\
+    return raw.replace('--lang', '\\\n  --lang') \
+        .replace('--encoder', '\\\n  --encoder') \
+        .replace('--topics', '\\\n  --topics') \
+        .replace('--index', '\\\n  --index') \
+        .replace('--output ', '\\\n  --output ') \
         .replace('--batch ', '\\\n  --batch ') \
         .replace('--threads 12', '--threads 12 \\\n ')
 
 
 def format_eval_command(raw):
-    return raw.replace('-c ', '\\\n  -c ')\
+    return raw.replace('-c ', '\\\n  -c ') \
         .replace(raw.split()[-1], f'\\\n  {raw.split()[-1]}')
 
 
@@ -164,8 +164,7 @@ def generate_table_rows(table, row_template, commands, eval_commands, table_id,
                          eval_cmd8=f'{eval_commands[keys["ru"]][metric]}',
                          eval_cmd9=f'{eval_commands[keys["sw"]][metric]}',
                          eval_cmd10=f'{eval_commands[keys["te"]][metric]}',
-                         eval_cmd11=f'{eval_commands[keys["th"]][metric]}'
-                         )
+                         eval_cmd11=f'{eval_commands[keys["th"]][metric]}')
 
         html_rows.append(s)
         row_cnt += 1
@@ -292,7 +291,7 @@ def run_conditions(args):
         print_results(table, metric, split)
 
     end = time.time()
-    print(f'Total elapsed time: {end - start:.0f}s')
+    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')
 
 
 if __name__ == '__main__':
diff --git a/pyserini/2cr/msmarco.py b/pyserini/2cr/msmarco.py
index 67f193439..0ce996ef3 100644
--- a/pyserini/2cr/msmarco.py
+++ b/pyserini/2cr/msmarco.py
@@ -271,7 +271,7 @@ def format_command(raw):
     # Format hybrid commands differently.
     if 'pyserini.search.hybrid' in raw:
         return raw.replace('dense', '\\\n  dense ') \
-            .replace('--encoder', '\\\n         --encoder')\
+            .replace('--encoder', '\\\n         --encoder') \
             .replace('sparse', '\\\n  sparse') \
             .replace('fusion', '\\\n  fusion') \
             .replace('run --', '\\\n  run --') \
@@ -282,12 +282,12 @@ def format_command(raw):
     # We want these on a separate line for better readability, but note that sometimes that might
     # be the end of the command, in which case we don't want to add an extra line break.
     return raw.replace('--topics', '\\\n  --topics') \
-        .replace('--threads', '\\\n  --threads')\
-        .replace('--index', '\\\n  --index')\
-        .replace('--output ', '\\\n  --output ')\
-        .replace('--encoder', '\\\n  --encoder')\
-        .replace('--onnx-encoder', '\\\n  --onnx-encoder')\
-        .replace('--encoded-corpus', '\\\n  --encoded-corpus')\
+        .replace('--threads', '\\\n  --threads') \
+        .replace('--index', '\\\n  --index') \
+        .replace('--output ', '\\\n  --output ') \
+        .replace('--encoder', '\\\n  --encoder') \
+        .replace('--onnx-encoder', '\\\n  --onnx-encoder') \
+        .replace('--encoded-corpus', '\\\n  --encoded-corpus') \
         .replace('.txt ', '.txt \\\n  ')
 
 
@@ -339,7 +339,7 @@ def generate_report(args):
            row_id = condition['display-row'] if 'display-row' in condition else ''
            cmd_template = condition['command']
 
-            row_ids[name] =row_id
+            row_ids[name] = row_id
            table_keys[name] = display
 
            for topic_set in condition['topics']:
@@ -483,7 +483,6 @@ def run_conditions(args):
                topic_key = topic_set['topic_key']
                eval_key = topic_set['eval_key']
 
-                short_topic_key = ''
                if args.collection == 'msmarco-v1-passage' or args.collection == 'msmarco-v1-doc':
                    short_topic_key = find_msmarco_table_topic_set_key_v1(topic_key)
                else:
@@ -608,7 +607,7 @@ def run_conditions(args):
 
     end = time.time()
     print('\n')
-    print(f'Total elapsed time: {end - start:.0f}s')
+    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')
 
 
 if __name__ == '__main__':
diff --git a/pyserini/2cr/odqa.py b/pyserini/2cr/odqa.py
index f59167350..56da66eb4 100644
--- a/pyserini/2cr/odqa.py
+++ b/pyserini/2cr/odqa.py
@@ -51,10 +51,9 @@
 PRINT_NQ_TOPICS = 'Natural Question'
 TQA_DKRR_RUN = f'run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt'
 NQ_DKRR_RUN = f'run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt'
 
-# HITS_1K = set(['GarT5-RRF', 'DPR-DKRR'])
-GARRRF_LS = ['answers','titles','sentences']
-HITS_1K = set(['GarT5-RRF', 'DPR-DKRR', 'DPR-Hybrid'])
+GARRRF_LS = ['answers', 'titles', 'sentences']
+HITS_1K = {'GarT5-RRF', 'DPR-DKRR', 'DPR-Hybrid'}
 
 
 def print_results(table, metric, topics):
@@ -78,6 +77,7 @@ def format_run_command(raw):
         .replace('--bm25', '\\\n  --bm25')\
         .replace('--hits 100', '\\\n  --hits 100')
 
+
 def format_hybrid_search_command(raw):
     return raw.replace('--encoder', '\\\n\t--encoder')\
         .replace(' dense', ' \\\n dense ')\
@@ -90,11 +90,12 @@ def format_hybrid_search_command(raw):
         .replace('--lang', '\\\n\t--lang')\
         .replace('--hits 100', '\\\n\t--hits 100')
 
+
 def format_convert_command(raw):
     return raw.replace('--topics', '\\\n  --topics')\
         .replace('--index', '\\\n  --index')\
         .replace('--input', '\\\n  --input')\
-        .replace('--output', '\\\n  --output')\
+        .replace('--output', '\\\n  --output')
 
 
 def format_eval_command(raw):
@@ -109,7 +110,8 @@ def read_file(f):
     return text
 
 
-def list_conditions(args):
+
+def list_conditions():
     for model in models['models']:
         print(model)
 
@@ -126,62 +128,60 @@ def generate_table_rows(table, table_id, commands, convert_commands, eval_comman
         if model == "GarT5-RRF":
             s = Template(row_template_garrrf)
             s = s.substitute(table_cnt=table_id,
-                row_cnt=row_cnt,
-                model=model,
-                TQA_Top20=table[model][TQA_TOPICS]["Top20"],
-                TQA_Top100=table[model][TQA_TOPICS]["Top100"],
-                NQ_Top20=table[model][NQ_TOPICS]["Top20"],
-                NQ_Top100=table[model][NQ_TOPICS]["Top100"],
-                cmd1=f'{commands[model][TQA_TOPICS][0]}',
-                cmd2=f'{commands[model][TQA_TOPICS][1]}',
-                cmd3=f'{commands[model][TQA_TOPICS][2]}',
-                cmd4=f'{commands[model][NQ_TOPICS][0]}',
-                cmd5=f'{commands[model][NQ_TOPICS][1]}',
-                cmd6=f'{commands[model][NQ_TOPICS][2]}',
-                fusion_cmd1=fusion_cmd_tqa[0],
-                fusion_cmd2=fusion_cmd_nq[0],
-                convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
-                convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
-                eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
-                eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}'
-                )
+                             row_cnt=row_cnt,
+                             model=model,
+                             TQA_Top20=table[model][TQA_TOPICS]["Top20"],
+                             TQA_Top100=table[model][TQA_TOPICS]["Top100"],
+                             NQ_Top20=table[model][NQ_TOPICS]["Top20"],
+                             NQ_Top100=table[model][NQ_TOPICS]["Top100"],
+                             cmd1=f'{commands[model][TQA_TOPICS][0]}',
+                             cmd2=f'{commands[model][TQA_TOPICS][1]}',
+                             cmd3=f'{commands[model][TQA_TOPICS][2]}',
+                             cmd4=f'{commands[model][NQ_TOPICS][0]}',
+                             cmd5=f'{commands[model][NQ_TOPICS][1]}',
+                             cmd6=f'{commands[model][NQ_TOPICS][2]}',
+                             fusion_cmd1=fusion_cmd_tqa[0],
+                             fusion_cmd2=fusion_cmd_nq[0],
+                             convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
+                             convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
+                             eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
+                             eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}')
         elif model == "GarT5RRF-DKRR-RRF":
             s = Template(row_template_rrf)
             s = s.substitute(table_cnt=table_id,
-                row_cnt=row_cnt,
-                model=model,
-                TQA_Top20=table[model][TQA_TOPICS]["Top20"],
-                TQA_Top100=table[model][TQA_TOPICS]["Top100"],
-                NQ_Top20=table[model][NQ_TOPICS]["Top20"],
-                NQ_Top100=table[model][NQ_TOPICS]["Top100"],
-                fusion_cmd1=fusion_cmd_tqa[1],
-                fusion_cmd2=fusion_cmd_nq[1],
-                convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
-                convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
-                eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
-                eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}'
-                )
+                             row_cnt=row_cnt,
+                             model=model,
+                             TQA_Top20=table[model][TQA_TOPICS]["Top20"],
+                             TQA_Top100=table[model][TQA_TOPICS]["Top100"],
+                             NQ_Top20=table[model][NQ_TOPICS]["Top20"],
+                             NQ_Top100=table[model][NQ_TOPICS]["Top100"],
+                             fusion_cmd1=fusion_cmd_tqa[1],
+                             fusion_cmd2=fusion_cmd_nq[1],
+                             convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
+                             convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
+                             eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
+                             eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}')
         else:
             s = Template(row_template)
             s = s.substitute(table_cnt=table_id,
-                row_cnt=row_cnt,
-                model=model,
-                TQA_Top20=table[model][TQA_TOPICS]["Top20"],
-                TQA_Top100=table[model][TQA_TOPICS]["Top100"],
-                NQ_Top20=table[model][NQ_TOPICS]["Top20"],
-                NQ_Top100=table[model][NQ_TOPICS]["Top100"],
-                cmd1=commands[model][TQA_TOPICS][0],
-                cmd2=commands[model][NQ_TOPICS][0],
-                convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
-                convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
-                eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
-                eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}'
-                )
+                             row_cnt=row_cnt,
+                             model=model,
+                             TQA_Top20=table[model][TQA_TOPICS]["Top20"],
+                             TQA_Top100=table[model][TQA_TOPICS]["Top100"],
+                             NQ_Top20=table[model][NQ_TOPICS]["Top20"],
+                             NQ_Top100=table[model][NQ_TOPICS]["Top100"],
+                             cmd1=commands[model][TQA_TOPICS][0],
+                             cmd2=commands[model][NQ_TOPICS][0],
+                             convert_cmd1=f'{convert_commands[model][TQA_TOPICS]}',
+                             convert_cmd2=f'{convert_commands[model][NQ_TOPICS]}',
+                             eval_cmd1=f'{eval_commands[model][TQA_TOPICS]}',
+                             eval_cmd2=f'{eval_commands[model][NQ_TOPICS]}')
         html_rows.append(s)
         row_cnt += 1
 
     return html_rows
 
+
 def generate_report(args):
     table = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))
     commands = defaultdict(lambda: defaultdict(lambda: []))
@@ -193,15 +193,12 @@ def generate_report(args):
     tqa_yaml_path = pkg_resources.resource_filename(__name__, 'triviaqa.yaml')
     nq_yaml_path = pkg_resources.resource_filename(__name__, 'naturalquestion.yaml')
 
-    garrrf_ls = ['answers','titles','sentences']
-    prefusion_runfile_tqa = []
-    prefusion_runfile_nq = []
+    garrrf_ls = ['answers', 'titles', 'sentences']
     fusion_cmd_tqa = []
     fusion_cmd_nq = []
     tqa_fused_run = {}
     nq_fused_run = {}
-
     with open(tqa_yaml_path) as f_tqa, open(nq_yaml_path) as f_nq:
         tqa_yaml_data = yaml.safe_load(f_tqa)
         nq_yaml_data = yaml.safe_load(f_nq)
@@ -211,8 +208,12 @@ def generate_report(args):
             cmd_template_nq = condition_nq['command']
             if 'RRF' in name:
                 if name == 'GarT5-RRF':
-                    runfile_tqa = [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_tqa))]
-                    runfile_nq = [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt') for i in range(len(cmd_template_nq))]
+                    runfile_tqa = \
+                        [os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt')
+                         for i in range(len(cmd_template_tqa))]
+                    runfile_nq = \
+                        [os.path.join(args.directory, f'run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt')
+                         for i in range(len(cmd_template_nq))]
                     tqa_fused_run.update({name: runfile_tqa[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')})
                     nq_fused_run.update({name: runfile_nq[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')})
                     jsonfile_tqa = tqa_fused_run[name].replace('.txt', '.json').replace('.hits-1000', '')
@@ -220,8 +221,8 @@ def generate_report(args):
                 elif name == 'GarT5RRF-DKRR-RRF':
                     jsonfile_tqa = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json')
                     jsonfile_nq = os.path.join(args.directory, f'run.odqa.{name}.{TQA_TOPICS}.json')
-                    tqa_fused_run.update({name: jsonfile_tqa.replace('.json','.txt')})
-                    nq_fused_run.update({name: jsonfile_nq.replace('.json','.txt')})
+                    tqa_fused_run.update({name: jsonfile_tqa.replace('.json', '.txt')})
+                    nq_fused_run.update({name: jsonfile_nq.replace('.json', '.txt')})
                 else:
                     raise NameError('Wrong model name in yaml config')
             else:
@@ -233,8 +234,8 @@ def generate_report(args):
                 jsonfile_tqa = runfile_tqa[0].replace('.answers', '').replace('.txt', '.json')
                 jsonfile_nq = runfile_nq[0].replace('.answers', '').replace('.txt', '.json')
 
-            display_runfile_tqa = jsonfile_tqa.replace('.json','.txt')
-            display_runfile_nq = jsonfile_nq.replace('.json','.txt')
+            display_runfile_tqa = jsonfile_tqa.replace('.json', '.txt')
+            display_runfile_nq = jsonfile_nq.replace('.json', '.txt')
 
             # fusion commands
             if "RRF" in name:
@@ -245,14 +246,14 @@ def generate_report(args):
 
                 tqa_runs = ' \\\n\t '.join(runfile_tqa)
                 nq_runs = ' \\\n\t '.join(runfile_nq)
-                fusion_cmd_tqa.append(f'python -m pyserini.fusion \\\n' + \
-                                      f'  --runs {tqa_runs} \\\n' + \
-                                      f'  --output {tqa_fused_run[name]} \\\n'
-                                      f'  --k 100')
-                fusion_cmd_nq.append(f'python -m pyserini.fusion \\\n' + \
-                                     f'  --runs {nq_runs} \\\n' + \
-                                     f'  --output {nq_fused_run[name]} \\\n' + \
-                                     f'  --k 100')
+                fusion_cmd_tqa.append(f'python -m pyserini.fusion \\\n' +
+                                      f'  --runs {tqa_runs} \\\n' +
+                                      f'  --output {tqa_fused_run[name]} \\\n' +
+                                      f'  --k 100')
+                fusion_cmd_nq.append(f'python -m pyserini.fusion \\\n' +
+                                     f'  --runs {nq_runs} \\\n' +
+                                     f'  --output {nq_fused_run[name]} \\\n' +
+                                     f'  --k 100')
 
             if name != "GarT5RRF-DKRR-RRF":
                 hits = 100 if name not in HITS_1K else 1000
@@ -266,7 +267,7 @@ def generate_report(args):
                 commands[name][TQA_TOPICS].extend([format_run_command(i) for i in cmd_tqa])
                 commands[name][NQ_TOPICS].extend([format_run_command(i) for i in cmd_nq])
 
-            # convertion commands:
+            # conversion commands:
             if 'dpr-topics' in name:
                 temp_nq_topics = 'dpr-nq-test'
             else:
@@ -274,7 +275,7 @@ def generate_report(args):
 
             convert_cmd_tqa = f'python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run ' + \
                               f'--topics {TQA_TOPICS} ' + \
-                              f'--index wikipedia-dpr ' +\
+                              f'--index wikipedia-dpr ' + \
                              f'--input {display_runfile_tqa} ' + \
                              f'--output {jsonfile_tqa}'
             convert_cmd_nq = f'python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run ' + \
@@ -300,16 +301,19 @@ def generate_report(args):
             table[name][NQ_TOPICS].update(expected_nq)
 
     tables_html = []
-    html_rows = generate_table_rows(table, 1, commands, convert_commands, eval_commands, fusion_cmd_tqa=fusion_cmd_tqa, fusion_cmd_nq=fusion_cmd_nq)
+    html_rows = generate_table_rows(table, 1, commands, convert_commands,
+                                    eval_commands, fusion_cmd_tqa=fusion_cmd_tqa, fusion_cmd_nq=fusion_cmd_nq)
     all_rows = '\n'.join(html_rows)
     tables_html.append(Template(table_template).substitute(desc='Models', rows=all_rows))
 
     with open(args.output, 'w') as out:
         out.write(Template(html_template).substitute(title=f'Retrieval for Open-Domain QA Datasets', tables=' '.join(tables_html)))
 
+
 def run_conditions(args):
     hits = 1000 if args.full_topk else 100
-    yaml_path = pkg_resources.resource_filename(__name__, 'triviaqa.yaml') if args.topics == "tqa" else pkg_resources.resource_filename(__name__, 'naturalquestion.yaml')
+    yaml_path = pkg_resources.resource_filename(__name__, 'triviaqa.yaml') \
+        if args.topics == "tqa" else pkg_resources.resource_filename(__name__, 'naturalquestion.yaml')
     topics = 'dpr-trivia-test' if args.topics == 'tqa' else 'nq-test'
     start = time.time()
     table = defaultdict(lambda: defaultdict(lambda: 0.0))
@@ -350,8 +354,7 @@ def run_conditions(args):
             cmd = [Template(cmd_template[i]).substitute(output=runfile[i]) for i in range(len(runfile))]
             if hits == 100:
                 cmd = [i + ' --hits 100' for i in cmd]
-
-
+
             for i in range(len(runfile)):
                 if args.display_commands:
                     print(f'\n```bash\n{format_run_command(cmd[i])}\n```\n')
@@ -361,14 +364,13 @@ def run_conditions(args):
 
             # fusion
             if 'RRF' in name:
-                runs = []
-                output = ''
                 if name == 'GarT5-RRF':
                     runs = runfile
                     output = os.path.join(args.directory, f'run.odqa.{name}.{topics}.hits-{hits}.fusion.txt')
                 elif name == 'GarT5RRF-DKRR-RRF':
-                    runs = [os.path.join(args.directory, f'run.odqa.DPR-DKRR.{topics}.hits-1000.txt'), os.path.join(args.directory, f'run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')]
-                    output = runfile[0].replace('.txt','.fusion.txt')
+                    runs = [os.path.join(args.directory, f'run.odqa.DPR-DKRR.{topics}.hits-1000.txt'),
+                            os.path.join(args.directory, f'run.odqa.GarT5-RRF.{topics}.hits-1000.fusion.txt')]
+                    output = runfile[0].replace('.txt', '.fusion.txt')
                 else:
                     raise NameError('Unexpected model name')
                 if not os.path.exists(output) and not args.dry_run:
@@ -380,12 +382,12 @@ def run_conditions(args):
                         raise RuntimeError('fusion failed')
                 runfile = [output]
 
-            # trec conversion + evaluation
+            # TREC conversion + evaluation
             if not args.skip_eval:
                 if not os.path.exists(runfile[0]):
                     continue
                 jsonfile = runfile[0].replace('.txt', '.json')
-                runfile = jsonfile.replace('.json','.txt')
+                runfile = jsonfile.replace('.json', '.txt')
                 if not os.path.exists(jsonfile):
                     status = convert_trec_run_to_dpr_retrieval_json(topics, 'wikipedia-dpr-100w', runfile, jsonfile)
                     if status != 0:
@@ -409,7 +411,6 @@ def run_conditions(args):
                    table[name][metric] = score[metric]
            else:
                table[name][metric] = expected_score
-
        print('')
 
    metric_ls = ['Top5', 'Top20', 'Top100', 'Top500', 'Top1000']
    metric_ls = metric_ls[:3] if not args.full_topk else metric_ls
@@ -417,7 +418,7 @@ def run_conditions(args):
        print_results(table, metric, topics)
 
    end = time.time()
-    print(f'Total elapsed time: {end - start:.0f}s')
+    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')
 
 
 if __name__ == '__main__':
@@ -439,7 +440,7 @@ def run_conditions(args):
    args = parser.parse_args()
 
    if args.list_conditions:
-        list_conditions(args)
+        list_conditions()
        sys.exit()
 
    if args.generate_report:
@@ -462,4 +463,4 @@ def run_conditions(args):
            print('Specifying --all will run all conditions')
        sys.exit()
 
-    run_conditions(args)
\ No newline at end of file
+    run_conditions(args)
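Reviewer note (not part of the patch): the recurring ')\' -> ') \' edits are PEP 8 whitespace
cleanup around line-continuation backslashes; the display formatters behave exactly as before.
As a sanity check, here is a minimal sketch of what beir.py's format_run_command produces after
this change, using a made-up run command (topic, index, and output names are illustrative only):

    def format_run_command(raw):
        # Same replace chain as pyserini/2cr/beir.py after this patch.
        return raw.replace('--topics', '\\\n  --topics') \
            .replace('--index', '\\\n  --index') \
            .replace('--encoder-class', '\\\n  --encoder-class') \
            .replace('--output ', '\\\n  --output ') \
            .replace('--output-format trec', '\\\n  --output-format trec \\\n ') \
            .replace('--hits ', '\\\n  --hits ')

    raw = ('python -m pyserini.search.lucene '
           '--topics beir-v1.0.0-scifact-test '
           '--index beir-v1.0.0-scifact.flat '
           '--output run.beir.scifact.txt '
           '--hits 1000')
    print(format_run_command(raw))
    # python -m pyserini.search.lucene \
    #   --topics beir-v1.0.0-scifact-test \
    #   --index beir-v1.0.0-scifact.flat \
    #   --output run.beir.scifact.txt \
    #   --hits 1000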
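Similarly, the elapsed-time change at the tail of each run_conditions() is additive: the report
now prints an approximate hour count next to raw seconds, which is handy since a full regression
run can take hours. A sketch of the new output format (timing values illustrative):

    import time

    start = time.time()
    # ... run all conditions ...
    end = time.time()

    # Before: Total elapsed time: 5400s
    # After:  Total elapsed time: 5400s ~1.5hr
    print(f'Total elapsed time: {end - start:.0f}s ~{(end - start)/3600:.1f}hr')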