From 0b3ec904376d207a36f809944108720c49ff8ce1 Mon Sep 17 00:00:00 2001 From: Jimmy Lin Date: Sun, 27 Aug 2023 09:20:41 -0400 Subject: [PATCH] Fix paper integration tests due to path changes in anserini and other minor issues (#1602) --- integrations/papers/test_ecir2023.py | 2 +- integrations/papers/test_sigir2021.py | 3 +-- integrations/papers/test_sigir2022.py | 5 ++--- pyserini/search/faiss/_searcher.py | 7 +++++++ pyserini/search/lucene/_impact_searcher.py | 4 ++++ pyserini/search/lucene/_searcher.py | 7 +++++++ 6 files changed, 22 insertions(+), 6 deletions(-) diff --git a/integrations/papers/test_ecir2023.py b/integrations/papers/test_ecir2023.py index 4c6535e2b..f894483ed 100644 --- a/integrations/papers/test_ecir2023.py +++ b/integrations/papers/test_ecir2023.py @@ -19,7 +19,7 @@ import os import unittest -from integrations.utils import clean_files, run_command, parse_score, parse_score_qa +from integrations.utils import clean_files, run_command, parse_score_qa class TestECIR2023(unittest.TestCase): diff --git a/integrations/papers/test_sigir2021.py b/integrations/papers/test_sigir2021.py index b59506e6d..e422708cb 100644 --- a/integrations/papers/test_sigir2021.py +++ b/integrations/papers/test_sigir2021.py @@ -169,8 +169,7 @@ def test_section3_3(self): msmarco-passage-dev-subset {output_file}' stdout, stderr = run_command(eval_cmd) score = parse_score_msmarco(stdout, "MRR @10") - self.assertAlmostEqual(score, 0.1872, delta=0.0001) - # Temporary fix: this is Lucene 9 code running on Lucene 8 prebuilt index. + self.assertAlmostEqual(score, 0.1874, delta=0.0001) def tearDown(self): clean_files(self.temp_files) diff --git a/integrations/papers/test_sigir2022.py b/integrations/papers/test_sigir2022.py index 3ff3908fd..ff7ae0215 100644 --- a/integrations/papers/test_sigir2022.py +++ b/integrations/papers/test_sigir2022.py @@ -22,7 +22,7 @@ from integrations.utils import clean_files, run_command, parse_score, parse_score_msmarco -class TestSIGIR2021(unittest.TestCase): +class TestSIGIR2022(unittest.TestCase): def setUp(self): self.temp_files = [] @@ -66,8 +66,7 @@ def test_Ma_etal_section4_1b(self): eval_cmd = f'python -m pyserini.eval.trec_eval -c -M 100 -m map -m recip_rank msmarco-v2-passage-dev {output_file}' stdout, stderr = run_command(eval_cmd) score = parse_score(stdout, "recip_rank") - self.assertAlmostEqual(score, 0.1501, delta=0.0001) - # This is the score with otf; with pre-encoded, the score is 0.1499. + self.assertAlmostEqual(score, 0.1499, delta=0.0001) def test_Trotman_etal(self): """Sample code in Trotman et al. demo paper.""" diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py index 828fc5ad2..23a024deb 100644 --- a/pyserini/search/faiss/_searcher.py +++ b/pyserini/search/faiss/_searcher.py @@ -418,6 +418,13 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, query_encoder: QueryEncod Searcher built from the prebuilt faiss index. """ print(f'Attempting to initialize pre-built index {prebuilt_index_name}.') + # see integrations/papers/test_sigir2021.py - preserve working commands published in papers + if prebuilt_index_name == 'msmarco-passage-tct_colbert-hnsw': + prebuilt_index_name = 'msmarco-v1-passage.tct_colbert.hnsw' + # see integrations/papers/test_ecir2023.py - preserve working commands published in papers + elif prebuilt_index_name == 'wikipedia-dpr-dkrr-nq': + prebuilt_index_name = 'wikipedia-dpr-100w.dkrr-nq' + try: index_dir = download_prebuilt_index(prebuilt_index_name) except ValueError as e: diff --git a/pyserini/search/lucene/_impact_searcher.py b/pyserini/search/lucene/_impact_searcher.py index 08d31ecef..81b551c8e 100644 --- a/pyserini/search/lucene/_impact_searcher.py +++ b/pyserini/search/lucene/_impact_searcher.py @@ -96,6 +96,10 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, query_encoder: Union[Quer Searcher built from the prebuilt index. """ print(f'Attempting to initialize pre-built index {prebuilt_index_name}.') + # see integrations/papers/test_sigir2021.py - preserve working commands published in papers + if prebuilt_index_name == 'msmarco-passage-unicoil-d2q': + prebuilt_index_name = 'msmarco-v1-passage-unicoil' + try: index_dir = download_prebuilt_index(prebuilt_index_name) except ValueError as e: diff --git a/pyserini/search/lucene/_searcher.py b/pyserini/search/lucene/_searcher.py index 45677db52..dc7eaaae6 100644 --- a/pyserini/search/lucene/_searcher.py +++ b/pyserini/search/lucene/_searcher.py @@ -69,6 +69,13 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, verbose=False): LuceneSearcher Searcher built from the prebuilt index. """ + # see integrations/papers/test_sigir2021.py - preserve working commands published in papers + if prebuilt_index_name == 'msmarco-passage': + prebuilt_index_name = 'msmarco-v1-passage' + # see integrations/papers/test_ecir2023.py - preserve working commands published in papers + elif prebuilt_index_name == 'wikipedia-dpr': + prebuilt_index_name = 'wikipedia-dpr-100w' + if verbose: print(f'Attempting to initialize pre-built index {prebuilt_index_name}.')