diff --git a/integrations/dense/test_ance.py b/integrations/dense/test_ance.py index a8c010255..1ff11aac2 100644 --- a/integrations/dense/test_ance.py +++ b/integrations/dense/test_ance.py @@ -16,8 +16,8 @@ """Integration tests for ANCE and ANCE PRF using on-the-fly query encoding.""" +import multiprocessing import os -import socket import unittest from integrations.utils import clean_files, run_command, parse_score, parse_score_qa, parse_score_msmarco @@ -29,30 +29,16 @@ class TestAnce(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 32 self.rocchio_alpha = 0.4 self.rocchio_beta = 0.6 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 - - def test_ance_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('ance-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('ance-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('ance-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores. + if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_msmarco_passage_ance_avg_prf_otf(self): output_file = 'test_run.dl2019.ance.avg-prf.otf.trec' diff --git a/integrations/dense/test_distilbert_kd.py b/integrations/dense/test_distilbert_kd.py deleted file mode 100644 index c476017df..000000000 --- a/integrations/dense/test_distilbert_kd.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Pyserini: Reproducible IR research with sparse and dense representations -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Integration tests for DistilBERT KD.""" - -import unittest - -from pyserini.search import QueryEncoder -from pyserini.search import get_topics - - -class TestDistilBertKd(unittest.TestCase): - # Note that we test actual retrieval in 2CR, so no need to test here.
- def test_distilbert_kd_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - -if __name__ == '__main__': - unittest.main() diff --git a/integrations/dense/test_distilbert_tasb.py b/integrations/dense/test_distilbert_tasb.py deleted file mode 100644 index 997ac601b..000000000 --- a/integrations/dense/test_distilbert_tasb.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Pyserini: Reproducible IR research with sparse and dense representations -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Integration tests for DistilBERT TAS-B.""" - -import unittest - -from pyserini.search import QueryEncoder -from pyserini.search import get_topics - - -class TestDistilBertTasB(unittest.TestCase): - def test_distilbert_kd_tas_b_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - -if __name__ == '__main__': - unittest.main() diff --git a/integrations/dense/test_dpr.py b/integrations/dense/test_dpr.py index a05d94546..ec57c6e88 100644 --- a/integrations/dense/test_dpr.py +++ b/integrations/dense/test_dpr.py @@ -17,8 +17,8 @@ """Integration tests for DPR model using pre-encoded queries.""" import json +import multiprocessing import os -import socket import unittest from integrations.utils import clean_files, run_command, parse_score_qa @@ -30,12 +30,14 @@ class TestDpr(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 32 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores.
+ if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_dpr_nq_test_bf_otf(self): output_file = 'test_run.dpr.nq-test.multi.bf.otf.trec' diff --git a/integrations/dense/test_sbert.py b/integrations/dense/test_sbert.py index 4ebab1668..884238282 100644 --- a/integrations/dense/test_sbert.py +++ b/integrations/dense/test_sbert.py @@ -16,25 +16,25 @@ """Integration tests for ANCE model using on-the-fly query encoding.""" +import multiprocessing import os -import socket import unittest from integrations.utils import clean_files, run_command, parse_score -from pyserini.search import QueryEncoder -from pyserini.search import get_topics class TestSBert(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 32 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores. + if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_msmarco_passage_sbert_bf_otf(self): output_file = 'test_run.msmarco-passage.sbert.bf.otf.tsv' @@ -53,12 +53,6 @@ def test_msmarco_passage_sbert_bf_otf(self): self.assertEqual(status, 0) self.assertAlmostEqual(score, 0.3314, delta=0.0001) - def test_msmarco_passage_sbert_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('sbert-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - def tearDown(self): clean_files(self.temp_files) diff --git a/integrations/dense/test_tct_colbert-v2.py b/integrations/dense/test_tct_colbert-v2.py index 37d6fd404..4a0c0c0f9 100644 --- a/integrations/dense/test_tct_colbert-v2.py +++ b/integrations/dense/test_tct_colbert-v2.py @@ -16,25 +16,25 @@ """Integration tests for TCT-ColBERTv2 models using on-the-fly query encoding.""" +import multiprocessing import os -import socket import unittest from integrations.utils import clean_files, run_command, parse_score -from pyserini.search import QueryEncoder -from pyserini.search import get_topics class TestTctColBertV2(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 32 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores.
+ if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_msmarco_passage_tct_colbert_v2_bf_otf(self): output_file = 'test_run.msmarco-passage.tct_colbert-v2.bf-otf.tsv' @@ -106,24 +106,6 @@ def test_msmarco_passage_tct_colbert_v2_hnp_bf_d2q_hybrid_otf(self): self.assertEqual(status, 0) self.assertAlmostEqual(score, 0.3731, delta=0.0001) - def test_msmarco_passage_tct_colbert_v2_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_passage_tct_colbert_v2_hn_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hn-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_passage_tct_colbert_v2_hnp_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hnp-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - def tearDown(self): clean_files(self.temp_files) diff --git a/integrations/dense/test_tct_colbert.py b/integrations/dense/test_tct_colbert.py index 18c9fc391..67b94d2c2 100644 --- a/integrations/dense/test_tct_colbert.py +++ b/integrations/dense/test_tct_colbert.py @@ -16,8 +16,8 @@ """Integration tests for TCT-ColBERTv1 models using on-the-fly query encoding.""" +import multiprocessing import os -import socket import unittest from integrations.utils import clean_files, run_command, parse_score @@ -29,12 +29,14 @@ class TestTctColBert(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 32 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores. + if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_msmarco_passage_tct_colbert_bf_otf(self): output_file = 'test_run.msmarco-passage.tct_colbert.bf-otf.tsv' @@ -169,12 +171,6 @@ def test_msmarco_doc_tct_colbert_bf_d2q_hybrid_otf(self): self.assertEqual(status, 0) self.assertAlmostEqual(score, 0.3784, places=4) - def test_msmarco_doc_tct_colbert_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-doc-dev') - topics = get_topics('msmarco-doc-dev') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - def tearDown(self): clean_files(self.temp_files) diff --git a/integrations/papers/test_ecir2023.py b/integrations/papers/test_ecir2023.py index f894483ed..e029c5f08 100644 --- a/integrations/papers/test_ecir2023.py +++ b/integrations/papers/test_ecir2023.py @@ -16,6 +16,7 @@ """Integration tests for commands in Pradeep et al.
resource paper at ECIR 2023.""" +import multiprocessing import os import unittest @@ -25,6 +26,15 @@ class TestECIR2023(unittest.TestCase): def setUp(self): self.temp_files = [] + self.threads = 16 + self.batch_size = self.threads * 32 + + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores. + if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 32 def test_section5_sub2_first(self): """Sample code of the first command in Section 5.2.""" @@ -42,7 +52,7 @@ def test_section5_sub2_first(self): --topics nq-test \ --encoder castorini/dkrr-dpr-nq-retriever \ --output {output_file} --query-prefix question: \ - --threads 72 --batch-size 72 \ + --threads {self.threads} --batch-size {self.batch_size} \ --hits 100' status = os.system(run_cmd) self.assertEqual(status, 0) diff --git a/integrations/papers/test_sigir2022.py b/integrations/papers/test_sigir2022.py index ff7ae0215..6665bb5f1 100644 --- a/integrations/papers/test_sigir2022.py +++ b/integrations/papers/test_sigir2022.py @@ -16,6 +16,7 @@ """Integration tests for commands in Ma et al. resource paper and Trotman et al. demo paper at SIGIR 2022.""" +import multiprocessing import os import unittest @@ -25,6 +26,15 @@ class TestSIGIR2022(unittest.TestCase): def setUp(self): self.temp_files = [] + self.threads = 16 + self.batch_size = self.threads * 8 + + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores. + if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 8 def test_Ma_etal_section4_1a(self): """Sample code in Section 4.1. in Ma et al. resource paper.""" @@ -57,7 +67,7 @@ def test_Ma_etal_section4_1b(self): --topics msmarco-v2-passage-dev \ --encoder castorini/unicoil-msmarco-passage \ --output {output_file} \ --batch {self.batch_size} --threads {self.threads} \ --hits 1000 \ --impact' status = os.system(run_cmd) @@ -78,7 +88,7 @@ def test_Trotman_etal(self): --topics msmarco-passage-dev-subset-unicoil \ --output {output_file} \ --output-format msmarco \ - --batch 36 --threads 12 \ + --batch {self.batch_size} --threads {self.threads} \ --hits 1000 \ --impact' status = os.system(run_cmd) diff --git a/integrations/dense/test_kilt.py b/integrations/sparse/test_kilt.py similarity index 85% rename from integrations/dense/test_kilt.py rename to integrations/sparse/test_kilt.py index b337cd8ac..4b27ab1df 100644 --- a/integrations/dense/test_kilt.py +++ b/integrations/sparse/test_kilt.py @@ -16,9 +16,9 @@ """Integration tests for KILT integration.""" +import multiprocessing import os import re -import socket import unittest from integrations.utils import clean_files, run_command @@ -37,12 +37,14 @@ class TestKilt(unittest.TestCase): def setUp(self): self.temp_files = [] self.threads = 16 - self.batch_size = 256 + self.batch_size = self.threads * 8 - # Hard-code larger values for internal servers - if socket.gethostname().startswith('damiano') or socket.gethostname().startswith('orca'): - self.threads = 36 - self.batch_size = 144 + half_cores = int(multiprocessing.cpu_count() / 2) + # If the server supports more threads, then use more threads. + # As a heuristic, use up to half of the available CPU cores.
+ if half_cores > self.threads: + self.threads = half_cores + self.batch_size = half_cores * 8 def test_kilt_search(self): run_file = 'test_run.fever-dev-kilt.jsonl' diff --git a/pyserini/resources/naturalquestion.yaml b/pyserini/resources/naturalquestion.yaml index e8050af03..7a6cecc0a 100644 --- a/pyserini/resources/naturalquestion.yaml +++ b/pyserini/resources/naturalquestion.yaml @@ -1,7 +1,7 @@ conditions: - model_name: BM25-k1_0.9_b_0.4 command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics nq-test --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics nq-test --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 44.82 Top20: 64.02 @@ -10,7 +10,7 @@ conditions: Top1000: 88.95 - model_name: BM25-k1_0.9_b_0.4_dpr-topics command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-nq-test --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-nq-test --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 43.77 Top20: 62.99 @@ -19,9 +19,9 @@ conditions: Top1000: 88.01 - model_name: GarT5-RRF command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics nq-test-gar-t5-answers --output $output --bm25 --k1 0.9 --b 0.4 - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics nq-test-gar-t5-titles --output $output --bm25 --k1 0.9 --b 0.4 - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics nq-test-gar-t5-sentences --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics nq-test-gar-t5-answers --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics nq-test-gar-t5-titles --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics nq-test-gar-t5-sentences --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 64.62 Top20: 77.17 @@ -30,7 +30,7 @@ conditions: Top1000: 92.91 - model_name: DPR command: - - python -m pyserini.search.faiss --threads 72 --batch-size 128 --index wikipedia-dpr-100w.dpr-single-nq --encoder facebook/dpr-question_encoder-single-nq-base --topics nq-test --output $output + - python -m pyserini.search.faiss --threads 16 --batch-size 512 --index wikipedia-dpr-100w.dpr-single-nq --encoder facebook/dpr-question_encoder-single-nq-base --topics nq-test --output $output scores: - Top5: 68.61 Top20: 80.58 @@ -39,7 +39,7 @@ conditions: Top1000: 91.83 - model_name: DPR-DKRR command: - - 'python -m pyserini.search.faiss --threads 72 --batch-size 128 --index wikipedia-dpr-100w.dkrr-nq --encoder castorini/dkrr-dpr-nq-retriever --topics nq-test --output $output --query-prefix question: ' + - 'python -m pyserini.search.faiss --threads 16 --batch-size 512 --index wikipedia-dpr-100w.dkrr-nq --encoder castorini/dkrr-dpr-nq-retriever --topics nq-test --output $output --query-prefix question: ' scores: - Top5: 73.80 Top20: 84.27 @@ -48,7 +48,7 @@ conditions: Top1000: 93.43 - model_name: DPR-Hybrid command: - - python -m pyserini.search.hybrid dense --index wikipedia-dpr-100w.dpr-single-nq --encoder 
facebook/dpr-question_encoder-single-nq-base sparse --index wikipedia-dpr-100w fusion --alpha 1.2 run --topics nq-test --output $output --threads 72 --batch-size 128 + - python -m pyserini.search.hybrid dense --index wikipedia-dpr-100w.dpr-single-nq --encoder facebook/dpr-question_encoder-single-nq-base sparse --index wikipedia-dpr-100w fusion --alpha 1.2 run --topics nq-test --output $output --threads 16 --batch-size 512 scores: - Top5: 72.52 Top20: 83.43 diff --git a/pyserini/resources/triviaqa.yaml b/pyserini/resources/triviaqa.yaml index a6d466a30..ea1dae6a0 100644 --- a/pyserini/resources/triviaqa.yaml +++ b/pyserini/resources/triviaqa.yaml @@ -1,7 +1,7 @@ conditions: - model_name: BM25-k1_0.9_b_0.4 command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-trivia-test --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-trivia-test --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 66.29 Top20: 76.41 @@ -10,7 +10,7 @@ conditions: Top1000: 88.50 - model_name: BM25-k1_0.9_b_0.4_dpr-topics command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-trivia-test --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-trivia-test --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 66.29 Top20: 76.41 @@ -19,9 +19,9 @@ conditions: Top1000: 88.50 - model_name: GarT5-RRF command: - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-answers --output $output --bm25 --k1 0.9 --b 0.4 - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-titles --output $output --bm25 --k1 0.9 --b 0.4 - - python -m pyserini.search.lucene --threads 72 --batch-size 128 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-sentences --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-answers --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-titles --output $output --bm25 --k1 0.9 --b 0.4 + - python -m pyserini.search.lucene --threads 16 --batch-size 512 --index wikipedia-dpr-100w --topics dpr-trivia-test-gar-t5-sentences --output $output --bm25 --k1 0.9 --b 0.4 scores: - Top5: 72.82 Top20: 80.66 @@ -30,7 +30,7 @@ conditions: Top1000: 90.06 - model_name: DPR command: - - python -m pyserini.search.faiss --threads 72 --batch-size 128 --index wikipedia-dpr-100w.dpr-multi --encoder facebook/dpr-question_encoder-multiset-base --topics dpr-trivia-test --output $output + - python -m pyserini.search.faiss --threads 16 --batch-size 512 --index wikipedia-dpr-100w.dpr-multi --encoder facebook/dpr-question_encoder-multiset-base --topics dpr-trivia-test --output $output scores: - Top5: 69.80 Top20: 78.87 @@ -39,7 +39,7 @@ conditions: Top1000: 89.30 - model_name: DPR-DKRR command: - - 'python -m pyserini.search.faiss --threads 72 --batch-size 128 --index wikipedia-dpr-100w.dkrr-tqa --encoder castorini/dkrr-dpr-tqa-retriever --topics dpr-trivia-test --output $output --query-prefix question: ' + - 'python -m pyserini.search.faiss --threads 16 --batch-size 
512 --index wikipedia-dpr-100w.dkrr-tqa --encoder castorini/dkrr-dpr-tqa-retriever --topics dpr-trivia-test --output $output --query-prefix question: ' scores: - Top5: 77.23 Top20: 83.74 @@ -48,7 +48,7 @@ conditions: Top1000: 90.63 - model_name: DPR-Hybrid command: - - python -m pyserini.search.hybrid dense --index wikipedia-dpr-100w.dpr-multi --encoder facebook/dpr-question_encoder-multiset-base sparse --index wikipedia-dpr-100w fusion --alpha 0.95 run --topics dpr-trivia-test --output $output --threads 72 --batch-size 128 + - python -m pyserini.search.hybrid dense --index wikipedia-dpr-100w.dpr-multi --encoder facebook/dpr-question_encoder-multiset-base sparse --index wikipedia-dpr-100w fusion --alpha 0.95 run --topics dpr-trivia-test --output $output --threads 16 --batch-size 512 scores: - Top5: 76.01 Top20: 82.64 diff --git a/pyserini/search/lucene/ltr/_search_msmarco.py b/pyserini/search/lucene/ltr/_search_msmarco.py index bc6ced6ae..7b1bb9504 100644 --- a/pyserini/search/lucene/ltr/_search_msmarco.py +++ b/pyserini/search/lucene/ltr/_search_msmarco.py @@ -205,7 +205,6 @@ def batch_extract(self, df, queries, fe): group = pd.DataFrame(group_lst, columns=['qid', 'count']) print(features.shape) print(task_infos.qid.drop_duplicates().shape) - print(group.mean()) print(features.head(10)) print(features.info()) yield task_infos, features, group @@ -219,7 +218,6 @@ def batch_extract(self, df, queries, fe): group = pd.DataFrame(group_lst, columns=['qid', 'count']) print(features.shape) print(task_infos.qid.drop_duplicates().shape) - print(group.mean()) print(features.head(10)) print(features.info()) yield task_infos, features, group diff --git a/scripts/classifier_prf/rank_trec_covid.py b/scripts/classifier_prf/rank_trec_covid.py index 73b5ae950..e879dd7c1 100644 --- a/scripts/classifier_prf/rank_trec_covid.py +++ b/scripts/classifier_prf/rank_trec_covid.py @@ -195,10 +195,10 @@ class VectorizerType(Enum): def evaluate(qrels_path: str, run_path: str, options: str = ''): curdir = os.getcwd() if curdir.endswith('clprf'): - anserini_root = '../../../anserini' + root = '..' else: - anserini_root = '../anserini' - prefix = f"{anserini_root}/tools/eval/trec_eval.9.0.4/trec_eval -c -M1000 -m all_trec {qrels_path}" + root = '.' + prefix = f"{root}/tools/eval/trec_eval.9.0.4/trec_eval -c -M1000 -m all_trec {qrels_path}" cmd1 = f"{prefix} {run_path} {options} | grep 'ndcg_cut_20 '" cmd2 = f"{prefix} {run_path} {options} | grep 'map '" ndcg_score = str(subprocess.check_output(cmd1, shell=True)).split('\\t')[-1].split('\\n')[0] diff --git a/scripts/jobs.integrations-all.txt b/scripts/jobs.integrations-all.txt new file mode 100644 index 000000000..0c8210a58 --- /dev/null +++ b/scripts/jobs.integrations-all.txt @@ -0,0 +1,4 @@ +python -m unittest discover -s integrations/dense > logs/log.dense 2>&1 +python -m unittest discover -s integrations/sparse > logs/log.sparse 2>&1 +python -m unittest discover -s integrations/clprf > logs/log.clprf 2>&1 +python -m unittest discover -s integrations/papers > logs/log.papers 2>&1 diff --git a/scripts/run_jobs_with_load.py b/scripts/run_jobs_with_load.py new file mode 100644 index 000000000..4a759d05e --- /dev/null +++ b/scripts/run_jobs_with_load.py @@ -0,0 +1,58 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import logging +import os +import time + +logger = logging.getLogger('run_jobs_with_load') +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s %(levelname)s [python] %(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Run jobs in parallel while maintaining a target load threshold.') + parser.add_argument('--file', type=str, default=None, help="File with commands, one per line.") + parser.add_argument('--sleep', type=int, default=30, help="Seconds to sleep between load checks.") + parser.add_argument('--load', type=int, default=10, help="Target load average.") + args = parser.parse_args() + + logger.info(f'Running commands in {args.file}') + logger.info(f'Sleep interval: {args.sleep}') + logger.info(f'Threshold load: {args.load}') + + with open(args.file) as f: + lines = f.read().splitlines() + + for r in lines: + if not r or r.startswith('#'): + continue + + logger.info(f'Launching: {r}') + os.system(r + ' &') + + while True: + time.sleep(args.sleep) + load = os.getloadavg()[0] + logger.info(f'Current load: {load:.1f} (threshold = {args.load})') + if load < args.load: + break diff --git a/tests/test_load_encoded_queries.py b/tests/test_load_encoded_queries.py new file mode 100644 index 000000000..721891e16 --- /dev/null +++ b/tests/test_load_encoded_queries.py @@ -0,0 +1,106 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +"""Tests for loading pre-encoded queries.""" + +import unittest + +from pyserini.search import QueryEncoder +from pyserini.search import get_topics + + +class TestLoadEncodedQueries(unittest.TestCase): + def test_ance_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('ance-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('ance-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('ance-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_distilbert_kd_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_distilbert_kd_tas_b_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_doc_tct_colbert_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-doc-dev') + topics = get_topics('msmarco-doc-dev') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_hn_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hn-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_hnp_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hnp-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_sbert_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('sbert-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': +
unittest.main() diff --git a/tests/test_load_topics.py b/tests/test_load_queries.py similarity index 100% rename from tests/test_load_topics.py rename to tests/test_load_queries.py