diff --git a/integrations/dense/test_ance.py b/integrations-optional/dense/test_ance.py similarity index 98% rename from integrations/dense/test_ance.py rename to integrations-optional/dense/test_ance.py index 248688bfb..7dcb5174d 100644 --- a/integrations/dense/test_ance.py +++ b/integrations-optional/dense/test_ance.py @@ -20,9 +20,9 @@ import os import unittest -from integrations.utils import clean_files, run_command, parse_score, parse_score_qa, parse_score_msmarco -from pyserini.search import QueryEncoder +from integrations.utils import clean_files, run_command, parse_score_qa, parse_score_msmarco from pyserini.search import get_topics +from pyserini.search.faiss._searcher import QueryEncoder class TestAnce(unittest.TestCase): diff --git a/integrations/dense/test_dpr.py b/integrations-optional/dense/test_dpr.py similarity index 99% rename from integrations/dense/test_dpr.py rename to integrations-optional/dense/test_dpr.py index 8ce400507..f06900251 100644 --- a/integrations/dense/test_dpr.py +++ b/integrations-optional/dense/test_dpr.py @@ -22,8 +22,8 @@ import unittest from integrations.utils import clean_files, run_command, parse_score_qa -from pyserini.search import QueryEncoder from pyserini.search import get_topics +from pyserini.search.faiss._searcher import QueryEncoder class TestDpr(unittest.TestCase): diff --git a/integrations/dense/test_encode.py b/integrations-optional/dense/test_encode.py similarity index 100% rename from integrations/dense/test_encode.py rename to integrations-optional/dense/test_encode.py diff --git a/integrations/dense/test_tct_colbert-v2.py b/integrations-optional/dense/test_tct_colbert-v2.py similarity index 100% rename from integrations/dense/test_tct_colbert-v2.py rename to integrations-optional/dense/test_tct_colbert-v2.py diff --git a/integrations/dense/test_tct_colbert.py b/integrations-optional/dense/test_tct_colbert.py similarity index 99% rename from integrations/dense/test_tct_colbert.py rename to integrations-optional/dense/test_tct_colbert.py index 1610200f1..f103e1a76 100644 --- a/integrations/dense/test_tct_colbert.py +++ b/integrations-optional/dense/test_tct_colbert.py @@ -21,8 +21,8 @@ import unittest from integrations.utils import clean_files, run_command, parse_score -from pyserini.search import QueryEncoder from pyserini.search import get_topics +from pyserini.search.faiss._searcher import QueryEncoder class TestTctColBert(unittest.TestCase): diff --git a/integrations/papers/test_ecir2023.py b/integrations-optional/papers/test_ecir2023.py similarity index 100% rename from integrations/papers/test_ecir2023.py rename to integrations-optional/papers/test_ecir2023.py diff --git a/integrations/papers/test_sigir2021.py b/integrations-optional/papers/test_sigir2021.py similarity index 97% rename from integrations/papers/test_sigir2021.py rename to integrations-optional/papers/test_sigir2021.py index 10d7fe5b1..2536d2ed1 100644 --- a/integrations/papers/test_sigir2021.py +++ b/integrations-optional/papers/test_sigir2021.py @@ -22,9 +22,9 @@ from integrations.utils import clean_files, run_command, parse_score_msmarco from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder from pyserini.hsearch import HybridSearcher -from pyserini.index import IndexReader -from pyserini.search import SimpleSearcher +from pyserini.index.lucene import LuceneIndexReader from pyserini.search import get_topics, get_qrels +from pyserini.search._deprecated import SimpleSearcher class TestSIGIR2021(unittest.TestCase): @@ -100,7 +100,7 @@ def test_figure5(self): """Sample code in Figure 5.""" # Initialize from a pre-built index: - reader = IndexReader.from_prebuilt_index('robust04') + reader = LuceneIndexReader.from_prebuilt_index('robust04') terms = reader.terms() term = next(terms) diff --git a/integrations/sparse/test_lucenesearcher_check_ltr_msmarco_document.py b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py similarity index 100% rename from integrations/sparse/test_lucenesearcher_check_ltr_msmarco_document.py rename to integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py diff --git a/integrations/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py similarity index 100% rename from integrations/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py rename to integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py diff --git a/integrations/sparse/test_nmslib.py b/integrations-optional/sparse/test_nmslib.py similarity index 100% rename from integrations/sparse/test_nmslib.py rename to integrations-optional/sparse/test_nmslib.py diff --git a/integrations/clprf/test_clprf.py b/integrations/clprf/test_clprf.py index 8c4595f59..49f29bf5e 100644 --- a/integrations/clprf/test_clprf.py +++ b/integrations/clprf/test_clprf.py @@ -21,7 +21,7 @@ from integrations.lucenesearcher_score_checker import LuceneSearcherScoreChecker from integrations.utils import run_command, parse_score -from pyserini.search import LuceneSearcher +from pyserini.search.lucene import LuceneSearcher class TestSearchIntegration(unittest.TestCase): diff --git a/integrations/sparse/test_lucenesearcher_check_core17.py b/integrations/sparse/test_lucenesearcher_check_core17.py index fcf7a9133..1596d26cf 100644 --- a/integrations/sparse/test_lucenesearcher_check_core17.py +++ b/integrations/sparse/test_lucenesearcher_check_core17.py @@ -17,7 +17,7 @@ import unittest from integrations.lucenesearcher_anserini_checker import LuceneSearcherAnseriniMatchChecker -from pyserini.search import LuceneSearcher +from pyserini.search.lucene import LuceneSearcher class CheckSearchResultsAgainstAnseriniForCore17(unittest.TestCase): diff --git a/integrations/sparse/test_lucenesearcher_check_core18.py b/integrations/sparse/test_lucenesearcher_check_core18.py index 640b4c0fa..9b8a18a03 100644 --- a/integrations/sparse/test_lucenesearcher_check_core18.py +++ b/integrations/sparse/test_lucenesearcher_check_core18.py @@ -17,7 +17,7 @@ import unittest from integrations.lucenesearcher_anserini_checker import LuceneSearcherAnseriniMatchChecker -from pyserini.search import LuceneSearcher +from pyserini.search.lucene import LuceneSearcher class CheckSearchResultsAgainstAnseriniForCore18(unittest.TestCase): diff --git a/integrations/sparse/test_lucenesearcher_check_robust04.py b/integrations/sparse/test_lucenesearcher_check_robust04.py index 24b5a9bda..9be1a10f0 100644 --- a/integrations/sparse/test_lucenesearcher_check_robust04.py +++ b/integrations/sparse/test_lucenesearcher_check_robust04.py @@ -17,7 +17,7 @@ import unittest from integrations.lucenesearcher_anserini_checker import LuceneSearcherAnseriniMatchChecker -from pyserini.search import LuceneSearcher +from pyserini.search.lucene import LuceneSearcher class CheckSearchResultsAgainstAnseriniForRobust04(unittest.TestCase): diff --git a/integrations/sparse/test_lucenesearcher_check_robust05.py b/integrations/sparse/test_lucenesearcher_check_robust05.py index 268db0967..e82e33cec 100644 --- a/integrations/sparse/test_lucenesearcher_check_robust05.py +++ b/integrations/sparse/test_lucenesearcher_check_robust05.py @@ -17,7 +17,7 @@ import unittest from integrations.lucenesearcher_anserini_checker import LuceneSearcherAnseriniMatchChecker -from pyserini.search import LuceneSearcher +from pyserini.search.lucene import LuceneSearcher class CheckSearchResultsAgainstAnseriniForRobust05(unittest.TestCase): diff --git a/integrations/sparse/test_simple_fusion_search_integration.py b/integrations/sparse/test_simple_fusion_search_integration.py index 5d377a451..8fc4c850d 100644 --- a/integrations/sparse/test_simple_fusion_search_integration.py +++ b/integrations/sparse/test_simple_fusion_search_integration.py @@ -24,7 +24,7 @@ from pyserini.fusion import FusionMethod from pyserini.search import get_topics -from pyserini.search import LuceneFusionSearcher +from pyserini.search.lucene import LuceneFusionSearcher from pyserini.trectools import TrecRun from pyserini.util import download_url, download_and_unpack_index diff --git a/pyserini/2cr/atomic.py b/pyserini/2cr/atomic.py index a7fac6ba2..b38c884e3 100644 --- a/pyserini/2cr/atomic.py +++ b/pyserini/2cr/atomic.py @@ -1,14 +1,31 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import argparse +import importlib.resources +import math import os import sys +import time from collections import defaultdict from string import Template -import importlib.resources -import time + import yaml -import math -from ._base import run_eval_and_return_metric, ok_str, fail_str +from ._base import run_eval_and_return_metric, ok_str, fail_str atomic_models = [ 'ViT-L-14.laion2b_s32b_b82k', diff --git a/pyserini/2cr/beir.py b/pyserini/2cr/beir.py index 7f093e380..a335b538f 100644 --- a/pyserini/2cr/beir.py +++ b/pyserini/2cr/beir.py @@ -15,6 +15,7 @@ # import argparse +import importlib.resources import math import os import sys @@ -23,7 +24,6 @@ from datetime import datetime from string import Template -import importlib.resources import yaml from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str diff --git a/pyserini/2cr/ciral.py b/pyserini/2cr/ciral.py index f446e7f03..9d70aeda1 100644 --- a/pyserini/2cr/ciral.py +++ b/pyserini/2cr/ciral.py @@ -15,18 +15,18 @@ # import argparse +import importlib.resources import math import os import sys import time -import importlib.resources from collections import defaultdict, OrderedDict from datetime import datetime from string import Template import yaml -from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str +from ._base import run_eval_and_return_metric, ok_str, fail_str dense_threads = 16 dense_batch_size = 512 diff --git a/pyserini/2cr/miracl.py b/pyserini/2cr/miracl.py index b88d4600d..b0a63bc06 100644 --- a/pyserini/2cr/miracl.py +++ b/pyserini/2cr/miracl.py @@ -15,6 +15,7 @@ # import argparse +import importlib.resources import math import os import subprocess @@ -24,7 +25,6 @@ from datetime import datetime from string import Template -import importlib.resources import yaml from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str diff --git a/pyserini/2cr/mrtydi.py b/pyserini/2cr/mrtydi.py index cd1f31b56..76cab7c98 100644 --- a/pyserini/2cr/mrtydi.py +++ b/pyserini/2cr/mrtydi.py @@ -15,6 +15,7 @@ # import argparse +import importlib.resources import math import os import sys @@ -23,7 +24,6 @@ from datetime import datetime from string import Template -import importlib.resources import yaml from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str diff --git a/pyserini/2cr/msmarco.py b/pyserini/2cr/msmarco.py index 2bb464f85..9bf3c356d 100644 --- a/pyserini/2cr/msmarco.py +++ b/pyserini/2cr/msmarco.py @@ -15,6 +15,7 @@ # import argparse +import importlib.resources import math import os import re @@ -24,7 +25,6 @@ from datetime import datetime from string import Template -import importlib.resources import yaml from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str diff --git a/pyserini/2cr/odqa.py b/pyserini/2cr/odqa.py index 2bc8021b8..f810bf5da 100644 --- a/pyserini/2cr/odqa.py +++ b/pyserini/2cr/odqa.py @@ -15,6 +15,7 @@ # import argparse +import importlib.resources import math import os import sys @@ -23,7 +24,6 @@ from datetime import datetime from string import Template -import importlib.resources import yaml from ._base import run_dpr_retrieval_eval_and_return_metric, convert_trec_run_to_dpr_retrieval_json, run_fusion, ok_str, \ diff --git a/pyserini/analysis/__init__.py b/pyserini/analysis/__init__.py index d3eed751b..0939ae7a7 100644 --- a/pyserini/analysis/__init__.py +++ b/pyserini/analysis/__init__.py @@ -15,5 +15,3 @@ # from ._base import get_lucene_analyzer, Analyzer, JAnalyzer, JAnalyzerUtils, JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer - -__all__ = ['get_lucene_analyzer', 'Analyzer', 'JAnalyzer', 'JAnalyzerUtils', 'JDefaultEnglishAnalyzer', 'JWhiteSpaceAnalyzer'] diff --git a/pyserini/analysis/_base.py b/pyserini/analysis/_base.py index 7ca17c5ec..2c7b4d1e6 100644 --- a/pyserini/analysis/_base.py +++ b/pyserini/analysis/_base.py @@ -16,7 +16,7 @@ from typing import List -from ..pyclass import autoclass +from pyserini.pyclass import autoclass # Wrappers around Lucene classes JAnalyzer = autoclass('org.apache.lucene.analysis.Analyzer') diff --git a/pyserini/collection/__init__.py b/pyserini/collection/__init__.py index 464516a63..2687ca1c1 100644 --- a/pyserini/collection/__init__.py +++ b/pyserini/collection/__init__.py @@ -16,5 +16,3 @@ from ._base import Collection, FileSegment, SourceDocument from ._collection_support import Cord19Article - -__all__ = ['Collection', 'FileSegment', 'SourceDocument', 'Cord19Article'] diff --git a/pyserini/demo/acl.py b/pyserini/demo/acl.py index 57d0b6f7d..4a0e12542 100644 --- a/pyserini/demo/acl.py +++ b/pyserini/demo/acl.py @@ -23,14 +23,16 @@ --port [PORT] --hits [Number of hits] --k1 [BM25 k1] --b [BM25 b] --device [cpu, cuda] """ -import json + import logging from argparse import ArgumentParser from functools import partial from typing import Callable, Optional, Tuple, Union -from flask import Flask, render_template, request, flash, jsonify -from pyserini.search import LuceneSearcher, FaissSearcher, AutoQueryEncoder +from flask import Flask, render_template, request, flash + +from pyserini.search.faiss import FaissSearcher +from pyserini.search.lucene import LuceneSearcher logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', diff --git a/pyserini/demo/atomic.py b/pyserini/demo/atomic.py index b9cb53488..7d45ebe17 100644 --- a/pyserini/demo/atomic.py +++ b/pyserini/demo/atomic.py @@ -23,14 +23,16 @@ --port [PORT] --hits [Number of hits] --index [BM25 or {dense retrieval flag}] --k1 [BM25 k1] --b [BM25 b] --device [cpu, cuda] """ + import json from argparse import ArgumentParser from functools import partial from typing import Callable, Optional, Tuple, Union from flask import Flask, render_template, request, flash, jsonify -from pyserini.search import LuceneSearcher, FaissSearcher +from pyserini.search.faiss import FaissSearcher +from pyserini.search.lucene import LuceneSearcher INDEX_NAMES = ( 'atomic_image_v0.2_small_validation', diff --git a/pyserini/demo/dpr.py b/pyserini/demo/dpr.py index 02e9aca42..e008fffc1 100644 --- a/pyserini/demo/dpr.py +++ b/pyserini/demo/dpr.py @@ -18,15 +18,15 @@ import json import random -from pyserini.search.lucene import LuceneSearcher +from pyserini.search import get_topics from pyserini.search.faiss import FaissSearcher, DprQueryEncoder from pyserini.search.hybrid import HybridSearcher -from pyserini import search +from pyserini.search.lucene import LuceneSearcher class DPRDemo(cmd.Cmd): - nq_dev_topics = list(search.get_topics('dpr-nq-dev').values()) - trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values()) + nq_dev_topics = list(get_topics('dpr-nq-dev').values()) + trivia_dev_topics = list(get_topics('dpr-trivia-dev').values()) ssearcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr') searcher = ssearcher diff --git a/pyserini/demo/miracl.py b/pyserini/demo/miracl.py index ffecb93f2..c6e4b2c10 100644 --- a/pyserini/demo/miracl.py +++ b/pyserini/demo/miracl.py @@ -23,6 +23,7 @@ --port [PORT] --hits [Number of hits] --index [BM25 or mdpr-tied-pft-msmarco] --k1 [BM25 k1] --b [BM25 b] --device [cpu, cuda] """ + import json import logging from argparse import ArgumentParser @@ -30,7 +31,10 @@ from typing import Callable, Optional, Tuple, Union from flask import Flask, render_template, request, flash, jsonify -from pyserini.search import LuceneSearcher, FaissSearcher, AutoQueryEncoder + +from pyserini.encode import AutoQueryEncoder +from pyserini.search.faiss import FaissSearcher +from pyserini.search.lucene import LuceneSearcher logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', diff --git a/pyserini/demo/msmarco.py b/pyserini/demo/msmarco.py index b73276d1c..4482da97a 100644 --- a/pyserini/demo/msmarco.py +++ b/pyserini/demo/msmarco.py @@ -16,17 +16,16 @@ import cmd import json -import os import random -from pyserini.search.lucene import LuceneSearcher +from pyserini.search import get_topics from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder from pyserini.search.hybrid import HybridSearcher -from pyserini import search +from pyserini.search.lucene import LuceneSearcher class MsMarcoDemo(cmd.Cmd): - dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values()) + dev_topics = list(get_topics('msmarco-passage-dev-subset').values()) ssearcher = LuceneSearcher.from_prebuilt_index('msmarco-passage') dsearcher = None diff --git a/pyserini/dsearch.py b/pyserini/dsearch.py index 72947e9e5..b65956aba 100644 --- a/pyserini/dsearch.py +++ b/pyserini/dsearch.py @@ -20,20 +20,20 @@ import os import sys -import pyserini.search.faiss -from pyserini.search.faiss import TctColBertQueryEncoder +from pyserini.search.faiss import FaissSearcher +from pyserini.search.faiss._searcher import TctColBertQueryEncoder, BinaryDenseSearcher __all__ = ['SimpleDenseSearcher', 'BinaryDenseSearcher', 'TctColBertQueryEncoder'] -class SimpleDenseSearcher(pyserini.search.faiss.FaissSearcher): +class SimpleDenseSearcher(FaissSearcher): def __new__(cls, *args, **kwargs): print('pyserini.dsearch.SimpleDenseSearcher class has been deprecated, ' 'please use FaissSearcher from pyserini.search.faiss instead') return super().__new__(cls) -class BinaryDenseSearcher(pyserini.search.faiss.BinaryDenseSearcher): +class BinaryDenseSearcher(BinaryDenseSearcher): def __new__(cls, *args, **kwargs): print('pyserini.dsearch.BinaryDenseSearcher class has been deprecated, ' 'please use BinaryDenseSearcher from pyserini.search.faiss instead') diff --git a/pyserini/encode/__init__.py b/pyserini/encode/__init__.py index c9e28ee5e..841e34800 100644 --- a/pyserini/encode/__init__.py +++ b/pyserini/encode/__init__.py @@ -14,18 +14,19 @@ # limitations under the License. # -from ._base import DocumentEncoder, QueryEncoder, JsonlCollectionIterator,\ - RepresentationWriter, FaissRepresentationWriter, JsonlRepresentationWriter, PcaEncoder +# This has to be first, otherwise we'll get circular import errors +from ._base import QueryEncoder, DocumentEncoder, JsonlCollectionIterator, JsonlRepresentationWriter + +# Then import these... +from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder from ._auto import AutoQueryEncoder, AutoDocumentEncoder +from ._cached_data import CachedDataQueryEncoder +from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder from ._dpr import DprDocumentEncoder, DprQueryEncoder +from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY +from ._slim import SlimQueryEncoder +from ._splade import SpladeQueryEncoder from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder -from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder -from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder -from ._cached_data import CachedDataQueryEncoder from ._tok_freq import TokFreqQueryEncoder -from ._splade import SpladeQueryEncoder -from ._slim import SlimQueryEncoder -from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY -from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder -from ._clip import ClipEncoder, ClipDocumentEncoder \ No newline at end of file +from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder diff --git a/pyserini/encode/__main__.py b/pyserini/encode/__main__.py index e7c8b3d78..37643b40e 100644 --- a/pyserini/encode/__main__.py +++ b/pyserini/encode/__main__.py @@ -17,11 +17,11 @@ import argparse import sys -from pyserini.encode import JsonlRepresentationWriter, FaissRepresentationWriter, JsonlCollectionIterator -from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AggretrieverDocumentEncoder, AutoDocumentEncoder, CosDprDocumentEncoder, ClipDocumentEncoder -from pyserini.encode import UniCoilDocumentEncoder -from pyserini.encode import OpenAIDocumentEncoder, OPENAI_API_RETRY_DELAY - +from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, \ + AggretrieverDocumentEncoder, AutoDocumentEncoder, CosDprDocumentEncoder, JsonlRepresentationWriter, \ + JsonlCollectionIterator, UniCoilDocumentEncoder, OpenAIDocumentEncoder, OPENAI_API_RETRY_DELAY +from pyserini.encode._clip import ClipDocumentEncoder +from pyserini.encode._faiss import FaissRepresentationWriter encoder_class_map = { "dpr": DprDocumentEncoder, diff --git a/pyserini/encode/_aggretriever.py b/pyserini/encode/_aggretriever.py index 224eb2b05..37268bca5 100644 --- a/pyserini/encode/_aggretriever.py +++ b/pyserini/encode/_aggretriever.py @@ -15,10 +15,11 @@ # from typing import Optional -import numpy as np + import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor + if torch.cuda.is_available(): from torch.cuda.amp import autocast diff --git a/pyserini/encode/_base.py b/pyserini/encode/_base.py index 15e4cdb65..47457ace4 100644 --- a/pyserini/encode/_base.py +++ b/pyserini/encode/_base.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import json import os -import faiss -import torch import numpy as np +import torch from tqdm import tqdm @@ -40,22 +40,6 @@ def encode(self, text, **kwargs): pass -class PcaEncoder: - def __init__(self, encoder, pca_model_path): - self.encoder = encoder - self.pca_mat = faiss.read_VectorTransform(pca_model_path) - - def encode(self, text, **kwargs): - if isinstance(text, str): - embeddings = self.encoder.encode(text, **kwargs) - embeddings = self.pca_mat.apply_py(np.array([embeddings])) - embeddings = embeddings[0] - else: - embeddings = self.encoder.encode(text, **kwargs) - embeddings = self.pca_mat.apply_py(embeddings) - return embeddings - - class JsonlCollectionIterator: def __init__(self, collection_path: str, fields=None, docid_field=None, delimiter="\n"): # Assume multimodal input files are located in the same directory as the collection file @@ -190,27 +174,3 @@ def write(self, batch_info, fields=None): self.file.write(json.dumps({'id': batch_info['id'][i], 'contents': contents, 'vector': vector}) + '\n') - - -class FaissRepresentationWriter(RepresentationWriter): - def __init__(self, dir_path, dimension=768): - self.dir_path = dir_path - self.index_name = 'index' - self.id_file_name = 'docid' - self.dimension = dimension - self.index = faiss.IndexFlatIP(self.dimension) - self.id_file = None - - def __enter__(self): - if not os.path.exists(self.dir_path): - os.makedirs(self.dir_path) - self.id_file = open(os.path.join(self.dir_path, self.id_file_name), 'w') - - def __exit__(self, exc_type, exc_val, exc_tb): - self.id_file.close() - faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name)) - - def write(self, batch_info, fields=None): - for id_ in batch_info['id']: - self.id_file.write(f'{id_}\n') - self.index.add(np.ascontiguousarray(batch_info['vector'])) diff --git a/pyserini/encode/_clip.py b/pyserini/encode/_clip.py index 648ede802..ce91c4f18 100644 --- a/pyserini/encode/_clip.py +++ b/pyserini/encode/_clip.py @@ -1,18 +1,33 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import os -import requests from io import BytesIO + +import requests +import torch from PIL import Image, ImageOps from sklearn.preprocessing import normalize from transformers import CLIPProcessor, CLIPModel -import torch from pyserini.encode import DocumentEncoder, QueryEncoder os.environ['OPENBLAS_NUM_THREADS'] = '1' - - def load_pil_image(image, format='RGB'): if isinstance(image, str) or os.path.isfile(image): if image.startswith(("http://", "https://")): # Image is a URL diff --git a/pyserini/encode/_faiss.py b/pyserini/encode/_faiss.py new file mode 100644 index 000000000..673b8f8b6 --- /dev/null +++ b/pyserini/encode/_faiss.py @@ -0,0 +1,46 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import faiss +import numpy as np + +from pyserini.encode._base import RepresentationWriter + + +class FaissRepresentationWriter(RepresentationWriter): + def __init__(self, dir_path, dimension=768): + self.dir_path = dir_path + self.index_name = 'index' + self.id_file_name = 'docid' + self.dimension = dimension + self.index = faiss.IndexFlatIP(self.dimension) + self.id_file = None + + def __enter__(self): + if not os.path.exists(self.dir_path): + os.makedirs(self.dir_path) + self.id_file = open(os.path.join(self.dir_path, self.id_file_name), 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + self.id_file.close() + faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name)) + + def write(self, batch_info, fields=None): + for id_ in batch_info['id']: + self.id_file.write(f'{id_}\n') + self.index.add(np.ascontiguousarray(batch_info['vector'])) diff --git a/pyserini/encode/_openai.py b/pyserini/encode/_openai.py index a06410f99..78e6f3561 100644 --- a/pyserini/encode/_openai.py +++ b/pyserini/encode/_openai.py @@ -1,10 +1,28 @@ -import openai -from typing import List +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import os import time -from pyserini.encode import DocumentEncoder, QueryEncoder -import tiktoken +from typing import List + import numpy as np +import openai +import tiktoken + +from pyserini.encode import DocumentEncoder, QueryEncoder api_key = '' if os.getenv("OPENAI_API_KEY") is None else os.getenv("OPENAI_API_KEY") org_key = '' if os.getenv("OPENAI_ORG_KEY") is None else os.getenv("OPENAI_ORG_KEY") diff --git a/pyserini/encode/_pca.py b/pyserini/encode/_pca.py new file mode 100644 index 000000000..1e8e2c967 --- /dev/null +++ b/pyserini/encode/_pca.py @@ -0,0 +1,34 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import faiss +import numpy as np + + +class PcaEncoder: + def __init__(self, encoder, pca_model_path): + self.encoder = encoder + self.pca_mat = faiss.read_VectorTransform(pca_model_path) + + def encode(self, text, **kwargs): + if isinstance(text, str): + embeddings = self.encoder.encode(text, **kwargs) + embeddings = self.pca_mat.apply_py(np.array([embeddings])) + embeddings = embeddings[0] + else: + embeddings = self.encoder.encode(text, **kwargs) + embeddings = self.pca_mat.apply_py(embeddings) + return embeddings diff --git a/pyserini/encode/_slim.py b/pyserini/encode/_slim.py index 76687e376..213f41be9 100644 --- a/pyserini/encode/_slim.py +++ b/pyserini/encode/_slim.py @@ -1,7 +1,22 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import scipy import torch from transformers import AutoModelForMaskedLM, AutoTokenizer -import numpy as np -import scipy from pyserini.encode import QueryEncoder diff --git a/pyserini/encode/_splade.py b/pyserini/encode/_splade.py index 23312c6d5..a156ceb73 100644 --- a/pyserini/encode/_splade.py +++ b/pyserini/encode/_splade.py @@ -1,6 +1,22 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np import torch from transformers import AutoModelForMaskedLM, AutoTokenizer -import numpy as np from pyserini.encode import QueryEncoder diff --git a/pyserini/encode/_tct_colbert.py b/pyserini/encode/_tct_colbert.py index e68dc7d42..f7ef51bdc 100644 --- a/pyserini/encode/_tct_colbert.py +++ b/pyserini/encode/_tct_colbert.py @@ -21,7 +21,7 @@ from transformers import BertModel, BertTokenizer, BertTokenizerFast from pyserini.encode import DocumentEncoder, QueryEncoder -from onnxruntime import ExecutionMode, SessionOptions, InferenceSession +from onnxruntime import SessionOptions, InferenceSession class TctColBertDocumentEncoder(DocumentEncoder): diff --git a/pyserini/encode/_unicoil.py b/pyserini/encode/_unicoil.py index 1f60f14ab..94abde56e 100644 --- a/pyserini/encode/_unicoil.py +++ b/pyserini/encode/_unicoil.py @@ -17,6 +17,7 @@ from typing import Optional import torch + if torch.cuda.is_available(): from torch.cuda.amp import autocast from transformers import BertConfig, BertModel, BertTokenizer, PreTrainedModel diff --git a/pyserini/encode/merge_faiss_index.py b/pyserini/encode/merge_faiss_index.py index bc0a929cd..9767c0d72 100644 --- a/pyserini/encode/merge_faiss_index.py +++ b/pyserini/encode/merge_faiss_index.py @@ -14,14 +14,13 @@ # limitations under the License. # -import os -import glob import argparse +import glob +import os import faiss from tqdm import tqdm - parser = argparse.ArgumentParser() parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768) parser.add_argument('--input', type=str, help='wildcard directory to input indexes', required=True) diff --git a/pyserini/encode/query.py b/pyserini/encode/query.py index 0dd768b7b..bdbaf5a24 100644 --- a/pyserini/encode/query.py +++ b/pyserini/encode/query.py @@ -16,12 +16,13 @@ import argparse -from tqdm import tqdm import numpy as np import pandas as pd +from tqdm import tqdm + +from pyserini.encode import DprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AutoQueryEncoder, \ + UniCoilQueryEncoder, SpladeQueryEncoder, OpenAIQueryEncoder, CosDprQueryEncoder from pyserini.query_iterator import DefaultQueryIterator -from pyserini.encode import DprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AutoQueryEncoder -from pyserini.encode import UniCoilQueryEncoder, SpladeQueryEncoder, OpenAIQueryEncoder, CosDprQueryEncoder def init_encoder(encoder, device, pooling, l2_norm, prefix): diff --git a/pyserini/eval/convert_trec_run_to_dpr_retrieval_run.py b/pyserini/eval/convert_trec_run_to_dpr_retrieval_run.py index 7ef59efd2..6e2eead7e 100644 --- a/pyserini/eval/convert_trec_run_to_dpr_retrieval_run.py +++ b/pyserini/eval/convert_trec_run_to_dpr_retrieval_run.py @@ -17,11 +17,12 @@ import argparse import json import os + from tqdm import tqdm +from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer from pyserini.search import get_topics, get_topics_with_reader from pyserini.search.lucene import LuceneSearcher -from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer if __name__ == '__main__': parser = argparse.ArgumentParser(description='Convert an TREC run to DPR retrieval result json.') diff --git a/pyserini/eval/evaluate_dpr_retrieval.py b/pyserini/eval/evaluate_dpr_retrieval.py index e494cd7bc..6aeb096e8 100644 --- a/pyserini/eval/evaluate_dpr_retrieval.py +++ b/pyserini/eval/evaluate_dpr_retrieval.py @@ -24,10 +24,10 @@ import logging import re import unicodedata -from tqdm import tqdm -import numpy as np +import numpy as np import regex +from tqdm import tqdm logger = logging.getLogger(__name__) diff --git a/pyserini/eval/evaluate_kilt_retrieval.py b/pyserini/eval/evaluate_kilt_retrieval.py index 6de8ac2ec..eef55fdf3 100644 --- a/pyserini/eval/evaluate_kilt_retrieval.py +++ b/pyserini/eval/evaluate_kilt_retrieval.py @@ -2,11 +2,11 @@ # https://github.com/facebookresearch/KILT/blob/9bcb119a7ed5fda88826058b062d0e45c726c676/kilt/eval_retrieval.py import argparse -import pprint import json +import os +import pprint from collections import defaultdict, OrderedDict -import os from pyserini.query_iterator import KiltQueryIterator diff --git a/pyserini/eval/evaluate_qa_overlap_retrieval.py b/pyserini/eval/evaluate_qa_overlap_retrieval.py index dd5163922..dd86559cb 100644 --- a/pyserini/eval/evaluate_qa_overlap_retrieval.py +++ b/pyserini/eval/evaluate_qa_overlap_retrieval.py @@ -19,16 +19,16 @@ """ import argparse +import collections import copy import json import logging +import os import re import unicodedata -from tqdm import tqdm + import numpy as np -import os import regex -import collections logger = logging.getLogger(__name__) diff --git a/pyserini/eval/msmarco_doc_eval.py b/pyserini/eval/msmarco_doc_eval.py index 4b818cde3..f5f51ab6a 100644 --- a/pyserini/eval/msmarco_doc_eval.py +++ b/pyserini/eval/msmarco_doc_eval.py @@ -15,9 +15,9 @@ # import os +import platform import subprocess import sys -import platform from pyserini.search import get_qrels_file from pyserini.util import download_evaluation_script diff --git a/pyserini/eval/msmarco_passage_eval.py b/pyserini/eval/msmarco_passage_eval.py index c5a07f950..d84af12c1 100644 --- a/pyserini/eval/msmarco_passage_eval.py +++ b/pyserini/eval/msmarco_passage_eval.py @@ -15,9 +15,9 @@ # import os +import platform import subprocess import sys -import platform from pyserini.search import get_qrels_file from pyserini.util import download_evaluation_script diff --git a/pyserini/eval/trec_eval.py b/pyserini/eval/trec_eval.py index 4e11ab904..63ffe2142 100644 --- a/pyserini/eval/trec_eval.py +++ b/pyserini/eval/trec_eval.py @@ -31,13 +31,14 @@ import glob import importlib.resources -import jnius_config import os -import pandas as pd import platform -import tempfile import subprocess import sys +import tempfile + +import jnius_config +import pandas as pd # Don't use the jdk.incubator.vector module. jar_directory = str(importlib.resources.files("pyserini.resources.jars").joinpath('')) @@ -45,9 +46,10 @@ jnius_config.add_classpath(jar_path) # This triggers loading of the JVM. -from jnius import autoclass +import jnius -# Now we can load qrels +# Now we can load qrels; this will trigger another attempt to reload the JVM, which won't happen because +# the JVM has already loaded. from pyserini.search import get_qrels_file cmd_prefix = ['java', '-cp', jar_path, 'trec_eval'] diff --git a/pyserini/fusion/__init__.py b/pyserini/fusion/__init__.py index 6eff3bfce..da4e22842 100644 --- a/pyserini/fusion/__init__.py +++ b/pyserini/fusion/__init__.py @@ -15,5 +15,3 @@ # from ._base import average, FusionMethod, interpolation, reciprocal_rank_fusion - -__all__ = ['FusionMethod', 'average', 'interpolation', 'reciprocal_rank_fusion'] diff --git a/pyserini/fusion/__main__.py b/pyserini/fusion/__main__.py index 640754ec9..582fd56e6 100644 --- a/pyserini/fusion/__main__.py +++ b/pyserini/fusion/__main__.py @@ -15,11 +15,11 @@ # import argparse -from ._base import FusionMethod + from pyserini.fusion import average, interpolation, reciprocal_rank_fusion +from ._base import FusionMethod from ..trectools import TrecRun - parser = argparse.ArgumentParser(description='Perform various ways of fusion given a list of trec run files.') parser.add_argument('--runs', type=str, nargs='+', default=[], required=True, help='List of run files separated by space.') diff --git a/pyserini/fusion/_base.py b/pyserini/fusion/_base.py index 674fc07e9..7b076fa9f 100644 --- a/pyserini/fusion/_base.py +++ b/pyserini/fusion/_base.py @@ -15,9 +15,10 @@ # from enum import Enum -from pyserini.trectools import AggregationMethod, RescoreMethod, TrecRun from typing import List +from pyserini.trectools import AggregationMethod, RescoreMethod, TrecRun + class FusionMethod(Enum): RRF = 'rrf' diff --git a/pyserini/index/__init__.py b/pyserini/index/__init__.py index 4120db77a..32dc65117 100644 --- a/pyserini/index/__init__.py +++ b/pyserini/index/__init__.py @@ -16,8 +16,6 @@ # Classes here have been moved to pyserini.index.lucene, e.g., the pyserini.index.Indexer is now # pyserini.index.lucene.IndexReader. We're importing symbols here and then re-exporting to preserve -# backward compatability to code snippets published in Lin et al. (SIGIR 2021). +# backward compatibility to code snippets published in Lin et al. (SIGIR 2021). -from .lucene._base import Document, Generator, IndexTerm, Posting, IndexReader - -__all__ = ['Document', 'Generator', 'IndexTerm', 'Posting', 'IndexReader'] +from .lucene._base import Document, Generator, IndexTerm, Posting, LuceneIndexReader diff --git a/pyserini/index/__main__.py b/pyserini/index/__main__.py index ecce08907..141693b0f 100644 --- a/pyserini/index/__main__.py +++ b/pyserini/index/__main__.py @@ -14,9 +14,10 @@ # limitations under the License. # -from jnius import autoclass -import sys import os +import sys + +from jnius import autoclass print('pyserini.index is deprecated, please use pyserini.index.lucene.') args = sys.argv[1:] diff --git a/pyserini/index/faiss.py b/pyserini/index/faiss.py index ba9ca14c2..2bf5b8a87 100644 --- a/pyserini/index/faiss.py +++ b/pyserini/index/faiss.py @@ -14,13 +14,13 @@ # limitations under the License. # +import argparse import json import os -import argparse import shutil -import numpy as np import faiss +import numpy as np from tqdm import tqdm if __name__ == '__main__': diff --git a/pyserini/index/lucene/__init__.py b/pyserini/index/lucene/__init__.py index 792675301..b88260543 100644 --- a/pyserini/index/lucene/__init__.py +++ b/pyserini/index/lucene/__init__.py @@ -14,8 +14,5 @@ # limitations under the License. # -from ._base import Document, Generator, IndexTerm, Posting, IndexReader +from ._base import Document, Generator, IndexTerm, Posting, LuceneIndexReader from ._indexer import LuceneIndexer, JacksonObjectMapper, JacksonJsonNode - -__all__ = ['Document', 'Generator', 'IndexTerm', 'Posting', 'IndexReader', 'LuceneIndexer', - 'JacksonObjectMapper', 'JacksonJsonNode'] \ No newline at end of file diff --git a/pyserini/index/lucene/__main__.py b/pyserini/index/lucene/__main__.py index 01d5a05cb..c4306abfd 100644 --- a/pyserini/index/lucene/__main__.py +++ b/pyserini/index/lucene/__main__.py @@ -14,10 +14,10 @@ # limitations under the License. # -from jnius import autoclass -import sys import os +import sys +from jnius import autoclass if __name__ == '__main__': args = sys.argv[1:] diff --git a/pyserini/index/lucene/_base.py b/pyserini/index/lucene/_base.py index 9c5dba095..d71ed5693 100644 --- a/pyserini/index/lucene/_base.py +++ b/pyserini/index/lucene/_base.py @@ -20,17 +20,18 @@ and methods provided are meant only to provide tools for examining an index and are not optimized for computing over. """ +import json import logging +import math from enum import Enum from typing import Dict, Iterator, List, Optional, Tuple + from tqdm import tqdm -import json -import math from pyserini.analysis import get_lucene_analyzer, JAnalyzer, JAnalyzerUtils +from pyserini.prebuilt_index_info import TF_INDEX_INFO, IMPACT_INDEX_INFO from pyserini.pyclass import autoclass from pyserini.util import download_prebuilt_index, get_sparse_indexes_info -from pyserini.prebuilt_index_info import TF_INDEX_INFO, IMPACT_INDEX_INFO logger = logging.getLogger(__name__) @@ -180,7 +181,7 @@ def __repr__(self): return repr -class IndexReader: +class LuceneIndexReader: """Wrapper class for ``IndexReaderUtils`` in Anserini. Parameters diff --git a/pyserini/index/merge_faiss_indexes.py b/pyserini/index/merge_faiss_indexes.py index 5662aae9f..5d930eb6f 100644 --- a/pyserini/index/merge_faiss_indexes.py +++ b/pyserini/index/merge_faiss_indexes.py @@ -15,10 +15,9 @@ # import argparse - -import faiss import os +import faiss parser = argparse.ArgumentParser() parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768) diff --git a/pyserini/output_writer.py b/pyserini/output_writer.py index e484a08d1..11e051141 100644 --- a/pyserini/output_writer.py +++ b/pyserini/output_writer.py @@ -16,12 +16,11 @@ import json import os - from abc import ABC, abstractmethod from enum import Enum, unique from typing import List -from pyserini.search import JScoredDoc +from pyserini.search.lucene import JScoredDoc @unique diff --git a/pyserini/query_iterator.py b/pyserini/query_iterator.py index 42d0918fd..cae6b19f5 100644 --- a/pyserini/query_iterator.py +++ b/pyserini/query_iterator.py @@ -14,16 +14,16 @@ # limitations under the License. # -import os import json +import os from abc import ABC, abstractmethod from enum import Enum, unique from pathlib import Path +from urllib.error import HTTPError, URLError +from pyserini.external_query_info import KILT_QUERY_INFO from pyserini.search import get_topics, get_topics_with_reader from pyserini.util import download_url, get_cache_home -from pyserini.external_query_info import KILT_QUERY_INFO -from urllib.error import HTTPError, URLError @unique diff --git a/pyserini/search/__init__.py b/pyserini/search/__init__.py index c07814b1a..6d4cd4c10 100644 --- a/pyserini/search/__init__.py +++ b/pyserini/search/__init__.py @@ -14,50 +14,4 @@ # limitations under the License. # -from ._base import JQuery, JQueryGenerator, JDisjunctionMaxQueryGenerator, get_topics,\ - get_topics_with_reader, get_qrels_file, get_qrels -from .lucene import JScoredDoc, LuceneSimilarities, LuceneFusionSearcher, LuceneSearcher -from .lucene import JScoredDoc, LuceneImpactSearcher -from ._deprecated import SimpleSearcher, ImpactSearcher, SimpleFusionSearcher - -from .faiss import DenseSearchResult, PRFDenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \ - DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, AutoQueryEncoder, ClipQueryEncoder -from .faiss import AnceEncoder -from .faiss import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf -from .faiss import OpenAIQueryEncoder - - -__all__ = ['JQuery', - 'LuceneSimilarities', - 'LuceneFusionSearcher', - 'LuceneSearcher', - 'JScoredDoc', - 'LuceneImpactSearcher', - 'JScoredDoc', - 'JDisjunctionMaxQueryGenerator', - 'JQueryGenerator', - 'get_topics', - 'get_topics_with_reader', - 'get_qrels_file', - 'get_qrels', - 'SimpleSearcher', - 'ImpactSearcher', - 'SimpleFusionSearcher', - 'DenseSearchResult', - 'PRFDenseSearchResult', - 'FaissSearcher', - 'BinaryDenseSearcher', - 'QueryEncoder', - 'DprQueryEncoder', - 'BprQueryEncoder', - 'DkrrDprQueryEncoder', - 'TctColBertQueryEncoder', - 'AnceEncoder', - 'AnceQueryEncoder', - 'AggretrieverQueryEncoder', - 'OpenAIQueryEncoder', - 'AutoQueryEncoder', - 'DenseVectorAveragePrf', - 'DenseVectorRocchioPrf', - 'DenseVectorAncePrf'] - +from ._base import get_topics, get_topics_with_reader, get_qrels_file, get_qrels diff --git a/pyserini/search/__main__.py b/pyserini/search/__main__.py index 57e4826ce..c40005af1 100644 --- a/pyserini/search/__main__.py +++ b/pyserini/search/__main__.py @@ -14,8 +14,8 @@ # limitations under the License. # -import sys import os +import sys print('WARNING: directly calling pyserini.search is deprecated, please use pyserini.search.lucene instead') args = " ".join(sys.argv[1:]) diff --git a/pyserini/search/_base.py b/pyserini/search/_base.py index 58ae840e3..c7e699752 100644 --- a/pyserini/search/_base.py +++ b/pyserini/search/_base.py @@ -28,7 +28,6 @@ logging.basicConfig(level=logging.WARNING, format='\n%(asctime)s - %(name)s - %(levelname)s - %(message)s') # Wrappers around Lucene classes -JQuery = autoclass('org.apache.lucene.search.Query') JPath = autoclass('java.nio.file.Path') # Wrappers around Anserini classes @@ -36,10 +35,6 @@ JRelevanceJudgments = autoclass('io.anserini.eval.RelevanceJudgments') JTopicReader = autoclass('io.anserini.search.topicreader.TopicReader') JTopics = autoclass('io.anserini.search.topicreader.Topics') -JQueryGenerator = autoclass('io.anserini.search.query.QueryGenerator') -JBagOfWordsQueryGenerator = autoclass('io.anserini.search.query.BagOfWordsQueryGenerator') -JDisjunctionMaxQueryGenerator = autoclass('io.anserini.search.query.DisjunctionMaxQueryGenerator') -JCovid19QueryGenerator = autoclass('io.anserini.search.query.Covid19QueryGenerator') # Function to safely get attributes from a class, returns None if not found diff --git a/pyserini/search/_deprecated.py b/pyserini/search/_deprecated.py index 0d877b086..894c058dd 100644 --- a/pyserini/search/_deprecated.py +++ b/pyserini/search/_deprecated.py @@ -14,7 +14,8 @@ # limitations under the License. # -from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher, LuceneFusionSearcher +from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher +from pyserini.search.lucene._searcher import LuceneFusionSearcher class SimpleSearcher(LuceneSearcher): diff --git a/pyserini/search/faiss/__init__.py b/pyserini/search/faiss/__init__.py index 25adcfcc1..a95c54181 100644 --- a/pyserini/search/faiss/__init__.py +++ b/pyserini/search/faiss/__init__.py @@ -14,14 +14,10 @@ # limitations under the License. # -from ._searcher import DenseSearchResult, PRFDenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \ - DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder, \ - AutoQueryEncoder, ClipQueryEncoder +from ._searcher import FaissSearcher, DenseSearchResult -from ._model import AnceEncoder -from._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf - -__all__ = ['DenseSearchResult', 'PRFDenseSearchResult', 'FaissSearcher', 'BinaryDenseSearcher', 'QueryEncoder', - 'DprQueryEncoder', 'BprQueryEncoder', 'DkrrDprQueryEncoder', 'TctColBertQueryEncoder', 'AnceEncoder', - 'AnceQueryEncoder', 'AggretrieverQueryEncoder', 'AutoQueryEncoder', 'DenseVectorAveragePrf', 'DenseVectorRocchioPrf', 'DenseVectorAncePrf', - 'OpenAIQueryEncoder', 'ClipQueryEncoder'] +# from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PRFDenseSearchResult +# from ._searcher import DenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \ +# DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, \ +# AggretrieverQueryEncoder, OpenAIQueryEncoder, \ +# AutoQueryEncoder, ClipQueryEncoder diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py index 58a2a5de5..d067038ee 100644 --- a/pyserini/search/faiss/__main__.py +++ b/pyserini/search/faiss/__main__.py @@ -16,20 +16,20 @@ import argparse import os -from typing import OrderedDict +import numpy as np from tqdm import tqdm -from pyserini.search import FaissSearcher, BinaryDenseSearcher, TctColBertQueryEncoder, QueryEncoder, \ - DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, DenseVectorAveragePrf, \ - DenseVectorRocchioPrf, DenseVectorAncePrf, OpenAIQueryEncoder, ClipQueryEncoder - -from pyserini.encode import PcaEncoder, CosDprQueryEncoder, AutoQueryEncoder -from pyserini.query_iterator import get_query_iterator, TopicsFormat +from pyserini.encode import CosDprQueryEncoder +from pyserini.encode._pca import PcaEncoder from pyserini.output_writer import get_output_writer, OutputFormat +from pyserini.query_iterator import get_query_iterator, TopicsFormat +from pyserini.search.faiss._searcher import (AutoQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder, + QueryEncoder, AnceQueryEncoder, BinaryDenseSearcher, BprQueryEncoder, + DprQueryEncoder, DkrrDprQueryEncoder, ClipQueryEncoder, TctColBertQueryEncoder) from pyserini.search.lucene import LuceneSearcher - -# from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf +from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf +from ._searcher import FaissSearcher # Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized." # https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial diff --git a/pyserini/search/faiss/_prf.py b/pyserini/search/faiss/_prf.py index 68167318d..3efbd48f7 100644 --- a/pyserini/search/faiss/_prf.py +++ b/pyserini/search/faiss/_prf.py @@ -1,8 +1,34 @@ -import numpy as np +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from dataclasses import dataclass from typing import List, Dict -from pyserini.search.faiss import PRFDenseSearchResult, AnceQueryEncoder + +import numpy as np + +from pyserini.encode import AnceQueryEncoder from pyserini.search.lucene import LuceneSearcher -import json + + +@dataclass +class PRFDenseSearchResult: + docid: str + score: float + vectors: [float] class DenseVectorPrf: @@ -17,7 +43,6 @@ def get_batch_prf_q_emb(self, **kwargs): class DenseVectorAveragePrf(DenseVectorPrf): - def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None): """Perform Average PRF with Dense Vectors diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py index db730da3b..f24e59932 100644 --- a/pyserini/search/faiss/_searcher.py +++ b/pyserini/search/faiss/_searcher.py @@ -24,27 +24,26 @@ from typing import Dict, List, Union, Optional, Tuple import numpy as np -import pandas as pd import openai +import pandas as pd import tiktoken - +import torch from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer) from transformers.file_utils import is_faiss_available, requires_backends +from pyserini.encode._clip import ClipEncoder +from pyserini.index import Document +from pyserini.search.lucene import LuceneSearcher from pyserini.util import (download_encoded_queries, download_prebuilt_index, get_dense_indexes_info, get_sparse_index) -from pyserini.search.lucene import LuceneSearcher -from pyserini.index import Document - from ._model import AnceEncoder -import torch - -from ...encode import PcaEncoder, CosDprQueryEncoder, ClipEncoder from ...encode._aggretriever import BERTAggretrieverEncoder, DistlBERTAggretrieverEncoder -if is_faiss_available(): - import faiss +from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PRFDenseSearchResult + +#if is_faiss_available(): +import faiss class QueryEncoder: @@ -421,13 +420,6 @@ class DenseSearchResult: score: float -@dataclass -class PRFDenseSearchResult: - docid: str - score: float - vectors: [float] - - class FaissSearcher: """Simple Searcher for dense representation diff --git a/pyserini/search/hybrid/__init__.py b/pyserini/search/hybrid/__init__.py index 61ca8b0c8..ed9f468b2 100644 --- a/pyserini/search/hybrid/__init__.py +++ b/pyserini/search/hybrid/__init__.py @@ -15,5 +15,3 @@ # from ._searcher import HybridSearcher - -__all__ = ['HybridSearcher'] \ No newline at end of file diff --git a/pyserini/search/hybrid/__main__.py b/pyserini/search/hybrid/__main__.py index 873e318ca..592f36488 100644 --- a/pyserini/search/hybrid/__main__.py +++ b/pyserini/search/hybrid/__main__.py @@ -20,13 +20,12 @@ from tqdm import tqdm -from pyserini.search.faiss import FaissSearcher -from pyserini.query_iterator import get_query_iterator, TopicsFormat from pyserini.output_writer import get_output_writer, OutputFormat -from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher -from pyserini.search.hybrid import HybridSearcher - +from pyserini.query_iterator import get_query_iterator, TopicsFormat +from pyserini.search.faiss import FaissSearcher from pyserini.search.faiss.__main__ import define_dsearch_args, init_query_encoder +from pyserini.search.hybrid import HybridSearcher +from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher from pyserini.search.lucene.__main__ import define_search_args, set_bm25_parameters # Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized." diff --git a/pyserini/search/hybrid/_searcher.py b/pyserini/search/hybrid/_searcher.py index 0817f6c85..10e4c21e9 100644 --- a/pyserini/search/hybrid/_searcher.py +++ b/pyserini/search/hybrid/_searcher.py @@ -19,8 +19,9 @@ """ from typing import List, Dict -from pyserini.search.lucene import LuceneSearcher + from pyserini.search.faiss import FaissSearcher, DenseSearchResult +from pyserini.search.lucene import LuceneSearcher class HybridSearcher: diff --git a/pyserini/search/lucene/__init__.py b/pyserini/search/lucene/__init__.py index dd570c557..928e72536 100644 --- a/pyserini/search/lucene/__init__.py +++ b/pyserini/search/lucene/__init__.py @@ -14,17 +14,16 @@ # limitations under the License. # -from ._geo_searcher import LuceneGeoSearcher -from ._impact_searcher import JScoredDoc, LuceneImpactSearcher, SlimSearcher -from ._searcher import JScoredDoc, LuceneSimilarities, LuceneFusionSearcher, LuceneSearcher -from ._hnsw_searcher import LuceneHnswDenseSearcher, LuceneFlatDenseSearcher +# We want to load the Java bindings first. +from pyserini.pyclass import autoclass + +JQuery = autoclass('org.apache.lucene.search.Query') +JScoredDoc = autoclass('io.anserini.search.ScoredDoc') +JQueryGenerator = autoclass('io.anserini.search.query.QueryGenerator') +JBagOfWordsQueryGenerator = autoclass('io.anserini.search.query.BagOfWordsQueryGenerator') +JDisjunctionMaxQueryGenerator = autoclass('io.anserini.search.query.DisjunctionMaxQueryGenerator') +JCovid19QueryGenerator = autoclass('io.anserini.search.query.Covid19QueryGenerator') -__all__ = ['JScoredDoc', - 'LuceneFusionSearcher', - 'LuceneGeoSearcher', - 'LuceneImpactSearcher', - 'LuceneSearcher', - 'LuceneHnswDenseSearcher', - 'LuceneFlatDenseSearcher', - 'SlimSearcher', - 'LuceneSimilarities'] +from ._impact_searcher import LuceneImpactSearcher, SlimSearcher +from ._searcher import LuceneSearcher, LuceneFusionSearcher, LuceneSimilarities +from ._hnsw_searcher import LuceneHnswDenseSearcher, LuceneFlatDenseSearcher diff --git a/pyserini/search/lucene/__main__.py b/pyserini/search/lucene/__main__.py index 9c7b8507f..ff68531e0 100644 --- a/pyserini/search/lucene/__main__.py +++ b/pyserini/search/lucene/__main__.py @@ -23,8 +23,10 @@ from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer from pyserini.output_writer import OutputFormat, get_output_writer from pyserini.query_iterator import get_query_iterator, TopicsFormat -from pyserini.search import JDisjunctionMaxQueryGenerator -from . import LuceneImpactSearcher, LuceneSearcher, SlimSearcher, LuceneHnswDenseSearcher, LuceneFlatDenseSearcher +from pyserini.search.lucene import JDisjunctionMaxQueryGenerator +from ._hnsw_searcher import LuceneHnswDenseSearcher, LuceneFlatDenseSearcher +from ._impact_searcher import LuceneImpactSearcher, SlimSearcher +from ._searcher import LuceneSearcher from .reranker import ClassifierType, PseudoRelevanceClassifierReranker diff --git a/pyserini/search/lucene/_geo_searcher.py b/pyserini/search/lucene/_geo_searcher.py index a5fb0b12d..ad67accc7 100644 --- a/pyserini/search/lucene/_geo_searcher.py +++ b/pyserini/search/lucene/_geo_searcher.py @@ -23,8 +23,7 @@ from typing import List from pyserini.pyclass import autoclass -from pyserini.search import JQuery - +from pyserini.search.lucene import JQuery logger = logging.getLogger(__name__) diff --git a/pyserini/search/lucene/_impact_searcher.py b/pyserini/search/lucene/_impact_searcher.py index 85cc6bfe9..dc36515c2 100644 --- a/pyserini/search/lucene/_impact_searcher.py +++ b/pyserini/search/lucene/_impact_searcher.py @@ -22,24 +22,24 @@ import logging import os import pickle -from tqdm import tqdm -from typing import Dict, List, Optional, Union from collections import namedtuple +from typing import Dict, List, Optional, Union import numpy as np import scipy +from tqdm import tqdm -from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, \ - CachedDataQueryEncoder, SpladeQueryEncoder, SlimQueryEncoder -from pyserini.index import Document +from pyserini.encode import QueryEncoder, CachedDataQueryEncoder, SlimQueryEncoder, SpladeQueryEncoder, \ + TokFreqQueryEncoder, UniCoilQueryEncoder +from pyserini.index.lucene import Document, LuceneIndexReader from pyserini.pyclass import autoclass, JFloat, JInt, JArrayList, JHashMap +from pyserini.search.lucene import JScoredDoc from pyserini.util import download_prebuilt_index, download_encoded_corpus logger = logging.getLogger(__name__) # Wrappers around Anserini classes JSimpleImpactSearcher = autoclass('io.anserini.search.SimpleImpactSearcher') -JScoredDoc = autoclass('io.anserini.search.ScoredDoc') class LuceneImpactSearcher: @@ -376,8 +376,7 @@ def _init_query_encoder_from_str(query_encoder): @staticmethod def _compute_idf(index_path): - from pyserini.index.lucene import IndexReader - index_reader = IndexReader(index_path) + index_reader = LuceneIndexReader(index_path) tokens = [] dfs = [] for term in index_reader.terms(): diff --git a/pyserini/search/lucene/_searcher.py b/pyserini/search/lucene/_searcher.py index f824baf70..8eb378db6 100644 --- a/pyserini/search/lucene/_searcher.py +++ b/pyserini/search/lucene/_searcher.py @@ -23,9 +23,9 @@ from typing import Dict, List, Optional, Union from pyserini.fusion import FusionMethod, reciprocal_rank_fusion -from pyserini.index import Document, IndexReader +from pyserini.index.lucene import Document, LuceneIndexReader from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap -from pyserini.search import JQuery, JQueryGenerator +from pyserini.search.lucene import JQuery, JQueryGenerator, JScoredDoc from pyserini.trectools import TrecRun from pyserini.util import download_prebuilt_index, get_sparse_indexes_info @@ -34,7 +34,6 @@ # Wrappers around Anserini classes JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher') -JScoredDoc = autoclass('io.anserini.search.ScoredDoc') class LuceneSearcher: @@ -78,10 +77,10 @@ def from_prebuilt_index(cls, prebuilt_index_name: str, verbose=False): print(str(e)) return None - # Currently, the only way to validate stats is to create a separate IndexReader, because there is no method + # Currently, the only way to validate stats is to create a separate LuceneIndexReader, because there is no method # to obtain the underlying reader of a SimpleSearcher; see https://github.com/castorini/anserini/issues/2013 - index_reader = IndexReader(index_dir) - # This is janky as we're created a separate IndexReader for the sole purpose of validating index stats. + index_reader = LuceneIndexReader(index_dir) + # This is janky as we're created a separate LuceneIndexReader for the sole purpose of validating index stats. index_reader.validate(prebuilt_index_name, verbose=verbose) if verbose: diff --git a/pyserini/search/lucene/irst/__init__.py b/pyserini/search/lucene/irst/__init__.py index 463ef3d62..c27321c85 100644 --- a/pyserini/search/lucene/irst/__init__.py +++ b/pyserini/search/lucene/irst/__init__.py @@ -15,4 +15,3 @@ # from ._searcher import LuceneIrstSearcher -__all__ = ['LuceneIrstSearcher'] diff --git a/pyserini/search/lucene/irst/__main__.py b/pyserini/search/lucene/irst/__main__.py index 27ed16b4a..c294a2ca8 100644 --- a/pyserini/search/lucene/irst/__main__.py +++ b/pyserini/search/lucene/irst/__main__.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import argparse from typing import List -from tqdm import tqdm + from transformers import AutoTokenizer + from pyserini.search import get_topics from pyserini.search.lucene.irst import LuceneIrstSearcher diff --git a/pyserini/search/lucene/irst/_searcher.py b/pyserini/search/lucene/irst/_searcher.py index 85a69ad2c..299647626 100644 --- a/pyserini/search/lucene/irst/_searcher.py +++ b/pyserini/search/lucene/irst/_searcher.py @@ -30,10 +30,10 @@ from transformers import AutoTokenizer +from pyserini.prebuilt_index_info import TF_INDEX_INFO from pyserini.pyclass import autoclass from pyserini.search.lucene import LuceneSearcher from pyserini.util import download_prebuilt_index, get_cache_home, download_url, download_and_unpack_index -from pyserini.prebuilt_index_info import TF_INDEX_INFO # Wrappers around Anserini classes JQuery = autoclass('org.apache.lucene.search.Query') diff --git a/pyserini/search/lucene/ltr/__main__.py b/pyserini/search/lucene/ltr/__main__.py index 4bef7c9eb..462407c93 100644 --- a/pyserini/search/lucene/ltr/__main__.py +++ b/pyserini/search/lucene/ltr/__main__.py @@ -15,15 +15,16 @@ # import argparse +from collections import defaultdict + import numpy as np import pandas as pd - from tqdm import tqdm -from collections import defaultdict from transformers import AutoTokenizer -from pyserini.search.lucene.ltr import * -from pyserini.search.lucene import LuceneSearcher + from pyserini.analysis import Analyzer, get_lucene_analyzer +from pyserini.search.lucene import LuceneSearcher +from pyserini.search.lucene.ltr import * """ Running prediction on candidates diff --git a/pyserini/search/lucene/ltr/_base.py b/pyserini/search/lucene/ltr/_base.py index 879b897ec..62f44055c 100644 --- a/pyserini/search/lucene/ltr/_base.py +++ b/pyserini/search/lucene/ltr/_base.py @@ -14,12 +14,15 @@ # limitations under the License. # -from pyserini.pyclass import autoclass import json +import re + import numpy as np import pandas as pd import spacy -import re + +from pyserini.pyclass import autoclass + class Feature: def name(self): diff --git a/pyserini/search/lucene/ltr/_search_msmarco.py b/pyserini/search/lucene/ltr/_search_msmarco.py index 24f9ea9aa..7b2a5f4f2 100644 --- a/pyserini/search/lucene/ltr/_search_msmarco.py +++ b/pyserini/search/lucene/ltr/_search_msmarco.py @@ -21,16 +21,16 @@ import logging import multiprocessing -import time import os -from tqdm import tqdm import pickle -from pyserini.index.lucene import IndexReader -from pyserini.search.lucene import LuceneSearcher -from pyserini.util import get_cache_home +import time -from pyserini.search.lucene.ltr._base import * +from tqdm import tqdm +from pyserini.index.lucene import LuceneIndexReader +from pyserini.search.lucene import LuceneSearcher +from pyserini.search.lucene.ltr._base import * +from pyserini.util import get_cache_home logger = logging.getLogger(__name__) @@ -46,10 +46,10 @@ def __init__(self, model: str, ibm_model:str, index:str, data: str, prebuilt: bo index_path = os.path.join(index_directory, 'lucene-inverted.msmarco-v1-passage.ltr.20210519.e25e33f.5da425ca44d2e3e5c38a7f564f13ad23') else: index_path = os.path.join(index_directory, 'lucene-inverted.msmarco-v1-doc-segmented.ltr.20211031.33e4151.86f108d8441b6845f8caf1208dd7ac7a') - self.index_reader = IndexReader.from_prebuilt_index(index) + self.index_reader = LuceneIndexReader.from_prebuilt_index(index) else: index_path = index - self.index_reader = IndexReader(index) + self.index_reader = LuceneIndexReader(index) self.fe = FeatureExtractor(index_path, max(multiprocessing.cpu_count()//2, 1)) self.data = data diff --git a/pyserini/search/lucene/querybuilder.py b/pyserini/search/lucene/querybuilder.py index 7627121c2..dec1cc8b8 100644 --- a/pyserini/search/lucene/querybuilder.py +++ b/pyserini/search/lucene/querybuilder.py @@ -17,6 +17,7 @@ """ This module provides Pyserini's Python interface query building for Anserini. """ + import logging from enum import Enum diff --git a/pyserini/search/lucene/reranker.py b/pyserini/search/lucene/reranker.py index e5fa17759..55747375f 100644 --- a/pyserini/search/lucene/reranker.py +++ b/pyserini/search/lucene/reranker.py @@ -16,11 +16,10 @@ import enum import importlib -import os -import uuid +from typing import List + from sklearn.linear_model import LogisticRegression from sklearn.svm import SVC -from typing import List class ClassifierType(enum.Enum): diff --git a/pyserini/search/nmslib/__init__.py b/pyserini/search/nmslib/__init__.py index e3188206e..ab7ea4ea9 100644 --- a/pyserini/search/nmslib/__init__.py +++ b/pyserini/search/nmslib/__init__.py @@ -15,5 +15,3 @@ # from ._searcher import SearchResult, NmslibSearcher - -__all__ = ['SearchResult', 'NmslibSearcher'] diff --git a/pyserini/search/nmslib/__main__.py b/pyserini/search/nmslib/__main__.py index 581fafa91..6d62db38f 100644 --- a/pyserini/search/nmslib/__main__.py +++ b/pyserini/search/nmslib/__main__.py @@ -17,10 +17,11 @@ import argparse import json import time + from tqdm import tqdm -from ._searcher import NmslibSearcher from pyserini.output_writer import get_output_writer, OutputFormat, tie_breaker +from ._searcher import NmslibSearcher if __name__ == '__main__': parser = argparse.ArgumentParser(description='Search a nmslib index.') diff --git a/pyserini/search/nmslib/_searcher.py b/pyserini/search/nmslib/_searcher.py index e0db6a247..b122164b3 100644 --- a/pyserini/search/nmslib/_searcher.py +++ b/pyserini/search/nmslib/_searcher.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import json import os from dataclasses import dataclass diff --git a/pyserini/server/AnseriniApplication/__main__.py b/pyserini/server/AnseriniApplication/__main__.py index 5e7d63eca..c7d3c5a17 100644 --- a/pyserini/server/AnseriniApplication/__main__.py +++ b/pyserini/server/AnseriniApplication/__main__.py @@ -14,8 +14,8 @@ # limitations under the License. # -import sys import signal +import sys from pyserini.pyclass import autoclass diff --git a/pyserini/trectools/_base.py b/pyserini/trectools/_base.py index 603cbb672..0271df8f7 100644 --- a/pyserini/trectools/_base.py +++ b/pyserini/trectools/_base.py @@ -15,14 +15,14 @@ # import itertools -import numpy as np -import pandas as pd - from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from enum import Enum from typing import List, Set, Tuple +import numpy as np +import pandas as pd + class AggregationMethod(Enum): SUM = 'sum' diff --git a/pyserini/util.py b/pyserini/util.py index 1649b3585..142f73386 100644 --- a/pyserini/util.py +++ b/pyserini/util.py @@ -15,24 +15,23 @@ # import hashlib +import logging import os import re import shutil import tarfile -import logging from urllib.error import HTTPError, URLError from urllib.request import urlretrieve import pandas as pd from tqdm import tqdm -from pyserini.encoded_query_info import QUERY_INFO from pyserini.encoded_corpus_info import CORPUS_INFO +from pyserini.encoded_query_info import QUERY_INFO from pyserini.evaluate_script_info import EVALUATION_INFO from pyserini.prebuilt_index_info import TF_INDEX_INFO, IMPACT_INDEX_INFO, \ LUCENE_HNSW_INDEX_INFO, LUCENE_FLAT_INDEX_INFO, FAISS_INDEX_INFO - logger = logging.getLogger(__name__) diff --git a/pyserini/vectorizer/__init__.py b/pyserini/vectorizer/__init__.py index dafc62521..e15a5bb9f 100644 --- a/pyserini/vectorizer/__init__.py +++ b/pyserini/vectorizer/__init__.py @@ -15,5 +15,3 @@ # from ._base import BM25Vectorizer, TfidfVectorizer - -__all__ = ['BM25Vectorizer', 'TfidfVectorizer'] diff --git a/pyserini/vectorizer/_base.py b/pyserini/vectorizer/_base.py index 255656c38..4be87743c 100644 --- a/pyserini/vectorizer/_base.py +++ b/pyserini/vectorizer/_base.py @@ -16,13 +16,14 @@ import math from typing import List, Optional -from sklearn.preprocessing import normalize from scipy.sparse import csr_matrix +from sklearn.preprocessing import normalize +from tqdm import tqdm -from pyserini import index, search from pyserini.analysis import Analyzer, get_lucene_analyzer -from tqdm import tqdm +from pyserini.index.lucene import LuceneIndexReader +from pyserini.search.lucene import LuceneSearcher class Vectorizer: @@ -41,8 +42,8 @@ class Vectorizer: def __init__(self, lucene_index_path: str, min_df: int = 1, verbose: bool = False): self.min_df: int = min_df self.verbose: bool = verbose - self.index_reader = index.IndexReader(lucene_index_path) - self.searcher = search.LuceneSearcher(lucene_index_path) + self.index_reader = LuceneIndexReader(lucene_index_path) + self.searcher = LuceneSearcher(lucene_index_path) self.num_docs: int = self.searcher.num_docs self.stats = self.index_reader.stats() self.analyzer = Analyzer(get_lucene_analyzer()) diff --git a/requirements.txt b/requirements.txt index 7738fd42c..b1ac808fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,24 @@ +tqdm +pyyaml +requests Cython>=0.29.21 numpy>=1.18.1 pandas>=1.4.0 -pyjnius>=1.4.0 +pyjnius>=1.6.0 scikit-learn>=0.22.1 scipy>=1.4.1 -tqdm transformers>=4.6.0 -sentencepiece>=0.1.95 -nmslib>=2.0.6 +torch>=2.0.0 onnxruntime>=1.8.1 -lightgbm>=3.3.2 -spacy>=3.2.1 -pyyaml openai>=1.0.0 +sentencepiece>=0.2 tiktoken>=0.4.0 +# below are going to be "optional", eventually +faiss>1.7.0 +flask>3.0 +nmslib>=2.0.6 +lightgbm>=3.3.2 +spacy>=3.2.1 pyarrow>=15.0.0 pillow>=10.2.0 pybind11>=2.11.0 diff --git a/tests-optional/test_encoder.py b/tests-optional/test_encoder.py new file mode 100644 index 000000000..b24b1477b --- /dev/null +++ b/tests-optional/test_encoder.py @@ -0,0 +1,343 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os +import pathlib as pl +import shutil +import tarfile +import unittest +from random import randint +from urllib.request import urlretrieve + +import faiss + +from pyserini.encode import TctColBertDocumentEncoder, DprDocumentEncoder, UniCoilDocumentEncoder +from pyserini.encode._clip import ClipDocumentEncoder +from pyserini.search.lucene import LuceneImpactSearcher + + +## We need to de-dup wrt tests/test_encoder + +class TestEncode(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.docids = [] + cls.texts = [] + cls.test_file = 'tests/resources/simple_cacm_corpus.json' + + with open(cls.test_file) as f: + for line in f: + line = json.loads(line) + cls.docids.append(line['id']) + cls.texts.append(line['contents']) + + # LuceneImpactSearcher requires a pre-built index to be initialized + r = randint(0, 10000000) + cls.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz' + cls.tarball_name = f'lucene-index.cacm-{r}.tar.gz' + cls.index_dir = f'index-{r}/' + + urlretrieve(cls.collection_url, cls.tarball_name) + + tarball = tarfile.open(cls.tarball_name) + tarball.extractall(cls.index_dir) + tarball.close() + + @staticmethod + def assertIsFile(path): + if not pl.Path(path).resolve().is_file(): + raise AssertionError("File does not exist: %s" % str(path)) + + def test_dpr_encoder(self): + encoder = DprDocumentEncoder('facebook/dpr-ctx_encoder-multiset-base', device='cpu') + vectors = encoder.encode(self.texts[:3]) + self.assertAlmostEqual(vectors[0][0], -0.59793323, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.13036962, places=4) + self.assertAlmostEqual(vectors[2][0], -0.3044764, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.1516793, places=4) + + def test_tct_colbert_encoder(self): + encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cpu') + vectors = encoder.encode(self.texts[:3]) + self.assertAlmostEqual(vectors[0][0], -0.01649557, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.05648308, places=4) + self.assertAlmostEqual(vectors[2][0], -0.10293338, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.05549275, places=4) + + def test_unicoil_encoder(self): + encoder = UniCoilDocumentEncoder('castorini/unicoil-msmarco-passage', device='cpu') + vectors = encoder.encode(self.texts[:3]) + self.assertAlmostEqual(vectors[0]['generation'], 2.2441017627716064, places=4) + self.assertAlmostEqual(vectors[0]['normal'], 2.4618067741394043, places=4) + self.assertAlmostEqual(vectors[2]['rounding'], 3.9474332332611084, places=4) + self.assertAlmostEqual(vectors[2]['commercial'], 3.288801670074463, places=4) + + def test_clip_encoder(self): + encoder = ClipDocumentEncoder('openai/clip-vit-base-patch32', device='cpu') + vectors = encoder.encode(self.texts[:3]) + self.assertAlmostEqual(vectors[0][0], 0.1933609, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.21501173, places=4) + self.assertAlmostEqual(vectors[2][0], 0.06461975, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.35396004, places=4) + + def test_tct_colbert_v2_encoder_cmd(self): + index_dir = 'temp_index' + cmd = f'python -m pyserini.encode \ + input --corpus {self.test_file} \ + --fields text \ + output --embeddings {index_dir} \ + encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \ + --fields text \ + --batch 1 \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + + embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') + self.assertIsFile(embedding_json_fn) + + with open(embedding_json_fn) as f: + embeddings = [json.loads(line) for line in f] + + self.assertListEqual([entry["id"] for entry in embeddings], self.docids) + self.assertListEqual( + [entry["contents"] for entry in embeddings], + [entry.strip() for entry in self.texts], + ) + + self.assertAlmostEqual(embeddings[0]['vector'][0], 0.12679848074913025, places=4) + self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.0037349488120526075, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][0], 0.03678430616855621, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.13209162652492523, places=4) + + shutil.rmtree(index_dir) + + def test_tct_colbert_v2_encoder_cmd_shard(self): + cleanup_list = [] + for shard_i in range(2): + index_dir = f'temp_index-{shard_i}' + cleanup_list.append(index_dir) + cmd = f'python -m pyserini.encode \ + input --corpus {self.test_file} \ + --fields text \ + --shard-id {shard_i} \ + --shard-num 2 \ + output --embeddings {index_dir} \ + --to-faiss \ + encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \ + --fields text \ + --batch 1 \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + self.assertIsFile(os.path.join(index_dir, 'docid')) + self.assertIsFile(os.path.join(index_dir, 'index')) + + cmd = f'python -m pyserini.index.merge_faiss_indexes --prefix temp_index- --shard-num 2' + index_dir = 'temp_index-full' + cleanup_list.append(index_dir) + docid_fn = os.path.join(index_dir, 'docid') + index_fn = os.path.join(index_dir, 'index') + + status = os.system(cmd) + self.assertEqual(status, 0) + self.assertIsFile(docid_fn) + self.assertIsFile(index_fn) + + index = faiss.read_index(index_fn) + vectors = index.reconstruct_n(0, index.ntotal) + + with open(docid_fn) as f: + self.assertListEqual([docid.strip() for docid in f], self.docids) + + self.assertAlmostEqual(vectors[0][0], 0.12679848074913025, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.0037349488120526075, places=4) + self.assertAlmostEqual(vectors[2][0], 0.03678430616855621, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.13209162652492523, places=4) + + for index_dir in cleanup_list: + shutil.rmtree(index_dir) + + def test_aggretriever_distilbert_encoder_cmd(self): + index_dir = 'temp_index' + cmd = f'python -m pyserini.encode \ + input --corpus {self.test_file} \ + --fields text \ + output --embeddings {index_dir} \ + encoder --encoder castorini/aggretriever-distilbert \ + --fields text \ + --batch 1 \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + + embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') + self.assertIsFile(embedding_json_fn) + + with open(embedding_json_fn) as f: + embeddings = [json.loads(line) for line in f] + + self.assertListEqual([entry["id"] for entry in embeddings], self.docids) + self.assertListEqual( + [entry["contents"] for entry in embeddings], + [entry.strip() for entry in self.texts], + ) + self.assertAlmostEqual(embeddings[0]['vector'][0], 0.14203716814517975, places=4) + self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.011851579882204533, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][0], 0.4780103862285614, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.0017992404755204916, places=4) + + shutil.rmtree(index_dir) + + def test_aggretriever_cocondenser_encoder_cmd(self): + index_dir = 'temp_index' + cmd = f'python -m pyserini.encode \ + input --corpus {self.test_file} \ + --fields text \ + output --embeddings {index_dir} \ + encoder --encoder castorini/aggretriever-cocondenser \ + --fields text \ + --batch 1 \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + + embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') + self.assertIsFile(embedding_json_fn) + + with open(embedding_json_fn) as f: + embeddings = [json.loads(line) for line in f] + + self.assertListEqual([entry["id"] for entry in embeddings], self.docids) + self.assertListEqual( + [entry["contents"] for entry in embeddings], + [entry.strip() for entry in self.texts], + ) + self.assertAlmostEqual(embeddings[0]['vector'][0], 0.4865410327911377, places=4) + self.assertAlmostEqual(embeddings[0]['vector'][-1], 0.006781343836337328, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][0], 0.32751473784446716, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.0014184381579980254, places=4) + + shutil.rmtree(index_dir) + + def test_onnx_encode_unicoil(self): + temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx') + + # this function will never be called in _impact_searcher, here to check quantization correctness + results = temp_object.encode("here is a test") + self.assertEqual(results.get("here"), 156) + self.assertEqual(results.get("a"), 31) + self.assertEqual(results.get("test"), 149) + + temp_object.close() + del temp_object + + temp_object1 = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'naver/splade-cocondenser-ensembledistil') + + # this function will never be called in _impact_searcher, here to check quantization correctness + results = temp_object1.encode("here is a test") + self.assertEqual(results.get("here"), 156) + self.assertEqual(results.get("a"), 31) + self.assertEqual(results.get("test"), 149) + + temp_object1.close() + del temp_object1 + + def test_clip_encoder_cmd_text(self): + index_dir = 'temp_index' + cmd = f'python -m pyserini.encode \ + input --corpus {self.test_file} \ + --fields text \ + output --embeddings {index_dir} \ + encoder --encoder openai/clip-vit-base-patch32 \ + --fields text \ + --batch 1 --max-length 77 \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + + embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') + self.assertIsFile(embedding_json_fn) + + with open(embedding_json_fn) as f: + embeddings = [json.loads(line) for line in f] + + self.assertListEqual([entry["id"] for entry in embeddings], self.docids) + self.assertListEqual( + [entry["contents"] for entry in embeddings], + [entry.strip() for entry in self.texts], + ) + + self.assertAlmostEqual(embeddings[0]['vector'][0], 0.022726990282535553, places=4) + self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.02527175098657608, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][0], 0.00724585447460413, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.039689723402261734, places=4) + + shutil.rmtree(index_dir) + + def test_clip_encoder_cmd_image(self): + # special case setup for image data + docids = [] + texts = [] + test_file = 'tests/resources/sample_collection_jsonl_image/images.small.jsonl' + image_dir = pl.Path(test_file).parent + + with open(test_file) as f: + for line in f: + line = json.loads(line) + docids.append(line['id']) + texts.append(line['path']) + + index_dir = 'temp_index' + cmd = f'python -m pyserini.encode \ + input --corpus {test_file} \ + --fields path \ + output --embeddings {index_dir} \ + encoder --encoder openai/clip-vit-base-patch32 \ + --fields path \ + --batch 1 --multimodal --l2-norm \ + --device cpu' + status = os.system(cmd) + self.assertEqual(status, 0) + + embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') + self.assertIsFile(embedding_json_fn) + + with open(embedding_json_fn) as f: + embeddings = [json.loads(line) for line in f] + + self.assertListEqual([entry["id"] for entry in embeddings], docids) + self.assertListEqual( + [entry["contents"] for entry in embeddings], + [str(pl.Path(image_dir, entry.strip())) for entry in texts], + ) + + self.assertAlmostEqual(embeddings[0]['vector'][0], 0.003283643862232566, places=4) + self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.055951327085494995, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][0], 0.021012384444475174, places=4) + self.assertAlmostEqual(embeddings[2]['vector'][-1], -0.0011692788684740663, places=4) + + shutil.rmtree(index_dir) + + @classmethod + def tearDownClass(cls): + os.remove(cls.tarball_name) + shutil.rmtree(cls.index_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_hybrid_search.py b/tests-optional/test_hybrid_search.py similarity index 95% rename from tests/test_hybrid_search.py rename to tests-optional/test_hybrid_search.py index 5b5567b09..153e677fb 100644 --- a/tests/test_hybrid_search.py +++ b/tests-optional/test_hybrid_search.py @@ -17,9 +17,10 @@ import unittest from typing import List, Dict -from pyserini.search.lucene import LuceneSearcher -from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder +from pyserini.search.faiss import FaissSearcher +from pyserini.search.faiss._searcher import AutoQueryEncoder from pyserini.search.hybrid import HybridSearcher +from pyserini.search.lucene import LuceneSearcher class TestHybridSearch(unittest.TestCase): diff --git a/tests/test_index_faiss.py b/tests-optional/test_index_faiss.py similarity index 100% rename from tests/test_index_faiss.py rename to tests-optional/test_index_faiss.py index 33b22c799..739bfa47a 100644 --- a/tests/test_index_faiss.py +++ b/tests-optional/test_index_faiss.py @@ -16,12 +16,12 @@ import json import os +import pathlib as pl import random - -import faiss import shutil import unittest -import pathlib as pl + +import faiss class TestIndexFaiss(unittest.TestCase): diff --git a/tests/test_load_encoded_queries.py b/tests-optional/test_load_encoded_queries.py similarity index 98% rename from tests/test_load_encoded_queries.py rename to tests-optional/test_load_encoded_queries.py index 721891e16..09cde0f28 100644 --- a/tests/test_load_encoded_queries.py +++ b/tests-optional/test_load_encoded_queries.py @@ -18,8 +18,8 @@ import unittest -from pyserini.search import QueryEncoder from pyserini.search import get_topics +from pyserini.search.faiss._searcher import QueryEncoder class TestLoadEncodedQueries(unittest.TestCase): diff --git a/tests/test_multimodal_search.py b/tests-optional/test_multimodal_search.py similarity index 97% rename from tests/test_multimodal_search.py rename to tests-optional/test_multimodal_search.py index cfd22f6c7..f1d7bf5e1 100644 --- a/tests/test_multimodal_search.py +++ b/tests-optional/test_multimodal_search.py @@ -15,13 +15,13 @@ # import os +import pathlib as pl import shutil import unittest -from typing import List, Dict - -from pyserini.search.faiss import FaissSearcher, ClipQueryEncoder -import pathlib as pl +from typing import List +from pyserini.search.faiss import FaissSearcher +from pyserini.search.faiss._searcher import ClipQueryEncoder class TestMultimodalSearch(unittest.TestCase): diff --git a/tests/test_nfcorpus.py b/tests-optional/test_nfcorpus_faiss.py similarity index 98% rename from tests/test_nfcorpus.py rename to tests-optional/test_nfcorpus_faiss.py index 9f0e629cc..1e8f19bd6 100644 --- a/tests/test_nfcorpus.py +++ b/tests-optional/test_nfcorpus_faiss.py @@ -30,6 +30,7 @@ def setUpClass(cls): cls.queries = 'tests/resources/nfcorpus-queries.tsv' cls.qrels = 'tests/resources/nfcorpus-qrels.tsv' + # TODO: Remove the Lucene part, just keep the Faiss part r = randint(0, 10000000) cls.dense_index_url = 'https://github.com/castorini/anserini-data/raw/master/NFCorpus/faiss.nfcorpus.contriever-msmacro.tar.gz' cls.dense_tarball_name = f'faiss.nfcorpus.contriever-msmacro-{r}.tar.gz' diff --git a/tests/test_analysis.py b/tests/test_analysis.py index e4c7ea8df..0d77effbe 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -22,7 +22,7 @@ from urllib.request import urlretrieve from pyserini.analysis import JAnalyzer, JAnalyzerUtils, Analyzer, get_lucene_analyzer -from pyserini.index.lucene import IndexReader +from pyserini.index.lucene import LuceneIndexReader from pyserini.search.lucene import LuceneSearcher @@ -40,7 +40,7 @@ def setUp(self): tarball.extractall(self.index_dir) tarball.close() self.searcher = LuceneSearcher(f'{self.index_dir}lucene9-index.cacm') - self.index_utils = IndexReader(f'{self.index_dir}lucene9-index.cacm') + self.index_utils = LuceneIndexReader(f'{self.index_dir}lucene9-index.cacm') def test_different_analyzers_are_different(self): self.searcher.set_analyzer(get_lucene_analyzer(stemming=False)) diff --git a/tests/test_analysis_lucene8.py b/tests/test_analysis_lucene8.py index 25fb59b9c..0efb160a4 100644 --- a/tests/test_analysis_lucene8.py +++ b/tests/test_analysis_lucene8.py @@ -22,7 +22,7 @@ from urllib.request import urlretrieve from pyserini.analysis import get_lucene_analyzer -from pyserini.index.lucene import IndexReader +from pyserini.index.lucene import LuceneIndexReader from pyserini.search.lucene import LuceneSearcher @@ -41,7 +41,7 @@ def setUp(self): tarball.extractall(self.index_dir) tarball.close() self.searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm') - self.index_utils = IndexReader(f'{self.index_dir}lucene-index.cacm') + self.index_utils = LuceneIndexReader(f'{self.index_dir}lucene-index.cacm') def test_different_analyzers_are_different(self): self.searcher.set_analyzer(get_lucene_analyzer(stemming=False)) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index 235bbebc2..852d75b37 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -23,9 +23,7 @@ from random import randint from urllib.request import urlretrieve -import faiss - -from pyserini.encode import TctColBertDocumentEncoder, DprDocumentEncoder, UniCoilDocumentEncoder, ClipDocumentEncoder +from pyserini.encode import TctColBertDocumentEncoder, DprDocumentEncoder, UniCoilDocumentEncoder from pyserini.search.lucene import LuceneImpactSearcher @@ -82,154 +80,6 @@ def test_unicoil_encoder(self): self.assertAlmostEqual(vectors[0]['normal'], 2.4618067741394043, places=4) self.assertAlmostEqual(vectors[2]['rounding'], 3.9474332332611084, places=4) self.assertAlmostEqual(vectors[2]['commercial'], 3.288801670074463, places=4) - - def test_clip_encoder(self): - encoder = ClipDocumentEncoder('openai/clip-vit-base-patch32', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0][0], 0.1933609, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.21501173, places=4) - self.assertAlmostEqual(vectors[2][0], 0.06461975, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.35396004, places=4) - - def test_tct_colbert_v2_encoder_cmd(self): - index_dir = 'temp_index' - cmd = f'python -m pyserini.encode \ - input --corpus {self.test_file} \ - --fields text \ - output --embeddings {index_dir} \ - encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \ - --fields text \ - --batch 1 \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - - embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') - self.assertIsFile(embedding_json_fn) - - with open(embedding_json_fn) as f: - embeddings = [json.loads(line) for line in f] - - self.assertListEqual([entry["id"] for entry in embeddings], self.docids) - self.assertListEqual( - [entry["contents"] for entry in embeddings], - [entry.strip() for entry in self.texts], - ) - - self.assertAlmostEqual(embeddings[0]['vector'][0], 0.12679848074913025, places=4) - self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.0037349488120526075, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][0], 0.03678430616855621, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.13209162652492523, places=4) - - shutil.rmtree(index_dir) - - def test_tct_colbert_v2_encoder_cmd_shard(self): - cleanup_list = [] - for shard_i in range(2): - index_dir = f'temp_index-{shard_i}' - cleanup_list.append(index_dir) - cmd = f'python -m pyserini.encode \ - input --corpus {self.test_file} \ - --fields text \ - --shard-id {shard_i} \ - --shard-num 2 \ - output --embeddings {index_dir} \ - --to-faiss \ - encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \ - --fields text \ - --batch 1 \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - self.assertIsFile(os.path.join(index_dir, 'docid')) - self.assertIsFile(os.path.join(index_dir, 'index')) - - cmd = f'python -m pyserini.index.merge_faiss_indexes --prefix temp_index- --shard-num 2' - index_dir = 'temp_index-full' - cleanup_list.append(index_dir) - docid_fn = os.path.join(index_dir, 'docid') - index_fn = os.path.join(index_dir, 'index') - - status = os.system(cmd) - self.assertEqual(status, 0) - self.assertIsFile(docid_fn) - self.assertIsFile(index_fn) - - index = faiss.read_index(index_fn) - vectors = index.reconstruct_n(0, index.ntotal) - - with open(docid_fn) as f: - self.assertListEqual([docid.strip() for docid in f], self.docids) - - self.assertAlmostEqual(vectors[0][0], 0.12679848074913025, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.0037349488120526075, places=4) - self.assertAlmostEqual(vectors[2][0], 0.03678430616855621, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.13209162652492523, places=4) - - for index_dir in cleanup_list: - shutil.rmtree(index_dir) - - def test_aggretriever_distilbert_encoder_cmd(self): - index_dir = 'temp_index' - cmd = f'python -m pyserini.encode \ - input --corpus {self.test_file} \ - --fields text \ - output --embeddings {index_dir} \ - encoder --encoder castorini/aggretriever-distilbert \ - --fields text \ - --batch 1 \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - - embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') - self.assertIsFile(embedding_json_fn) - - with open(embedding_json_fn) as f: - embeddings = [json.loads(line) for line in f] - - self.assertListEqual([entry["id"] for entry in embeddings], self.docids) - self.assertListEqual( - [entry["contents"] for entry in embeddings], - [entry.strip() for entry in self.texts], - ) - self.assertAlmostEqual(embeddings[0]['vector'][0], 0.14203716814517975, places=4) - self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.011851579882204533, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][0], 0.4780103862285614, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.0017992404755204916, places=4) - - shutil.rmtree(index_dir) - - def test_aggretriever_cocondenser_encoder_cmd(self): - index_dir = 'temp_index' - cmd = f'python -m pyserini.encode \ - input --corpus {self.test_file} \ - --fields text \ - output --embeddings {index_dir} \ - encoder --encoder castorini/aggretriever-cocondenser \ - --fields text \ - --batch 1 \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - - embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') - self.assertIsFile(embedding_json_fn) - - with open(embedding_json_fn) as f: - embeddings = [json.loads(line) for line in f] - - self.assertListEqual([entry["id"] for entry in embeddings], self.docids) - self.assertListEqual( - [entry["contents"] for entry in embeddings], - [entry.strip() for entry in self.texts], - ) - self.assertAlmostEqual(embeddings[0]['vector'][0], 0.4865410327911377, places=4) - self.assertAlmostEqual(embeddings[0]['vector'][-1], 0.006781343836337328, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][0], 0.32751473784446716, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.0014184381579980254, places=4) - - shutil.rmtree(index_dir) def test_onnx_encode_unicoil(self): temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx') @@ -254,82 +104,6 @@ def test_onnx_encode_unicoil(self): temp_object1.close() del temp_object1 - def test_clip_encoder_cmd_text(self): - index_dir = 'temp_index' - cmd = f'python -m pyserini.encode \ - input --corpus {self.test_file} \ - --fields text \ - output --embeddings {index_dir} \ - encoder --encoder openai/clip-vit-base-patch32 \ - --fields text \ - --batch 1 --max-length 77 \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - - embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') - self.assertIsFile(embedding_json_fn) - - with open(embedding_json_fn) as f: - embeddings = [json.loads(line) for line in f] - - self.assertListEqual([entry["id"] for entry in embeddings], self.docids) - self.assertListEqual( - [entry["contents"] for entry in embeddings], - [entry.strip() for entry in self.texts], - ) - - self.assertAlmostEqual(embeddings[0]['vector'][0], 0.022726990282535553, places=4) - self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.02527175098657608, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][0], 0.00724585447460413, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.039689723402261734, places=4) - - shutil.rmtree(index_dir) - - def test_clip_encoder_cmd_image(self): - # special case setup for image data - docids = [] - texts = [] - test_file = 'tests/resources/sample_collection_jsonl_image/images.small.jsonl' - image_dir = pl.Path(test_file).parent - - with open(test_file) as f: - for line in f: - line = json.loads(line) - docids.append(line['id']) - texts.append(line['path']) - - index_dir = 'temp_index' - cmd = f'python -m pyserini.encode \ - input --corpus {test_file} \ - --fields path \ - output --embeddings {index_dir} \ - encoder --encoder openai/clip-vit-base-patch32 \ - --fields path \ - --batch 1 --multimodal --l2-norm \ - --device cpu' - status = os.system(cmd) - self.assertEqual(status, 0) - - embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl') - self.assertIsFile(embedding_json_fn) - - with open(embedding_json_fn) as f: - embeddings = [json.loads(line) for line in f] - - self.assertListEqual([entry["id"] for entry in embeddings], docids) - self.assertListEqual( - [entry["contents"] for entry in embeddings], - [str(pl.Path(image_dir, entry.strip())) for entry in texts], - ) - - self.assertAlmostEqual(embeddings[0]['vector'][0], 0.003283643862232566, places=4) - self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.055951327085494995, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][0], 0.021012384444475174, places=4) - self.assertAlmostEqual(embeddings[2]['vector'][-1], -0.0011692788684740663, places=4) - - shutil.rmtree(index_dir) - @classmethod def tearDownClass(cls): os.remove(cls.tarball_name) diff --git a/tests/test_index_otf.py b/tests/test_index_otf.py index 9968009ff..42c12f678 100644 --- a/tests/test_index_otf.py +++ b/tests/test_index_otf.py @@ -15,12 +15,12 @@ # import os +import random import shutil import unittest -import random from typing import List -from pyserini.index.lucene import LuceneIndexer, IndexReader, JacksonObjectMapper +from pyserini.index.lucene import LuceneIndexer, LuceneIndexReader, JacksonObjectMapper from pyserini.search.lucene import JScoredDoc, LuceneSearcher @@ -151,7 +151,7 @@ def test_indexer_append1(self): indexer.add_doc_raw('{"id": "0", "contents": "Document 0"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(1, stats['documents']) self.assertIsNotNone(reader.doc('0')) @@ -160,7 +160,7 @@ def test_indexer_append1(self): indexer.add_doc_raw('{"id": "1", "contents": "Document 1"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('0')) @@ -172,7 +172,7 @@ def test_indexer_append2(self): indexer.add_doc_raw('{"id": "0", "contents": "Document 0"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(1, stats['documents']) self.assertIsNotNone(reader.doc('0')) @@ -182,7 +182,7 @@ def test_indexer_append2(self): indexer.add_doc_raw('{"id": "1", "contents": "Document 1"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(1, stats['documents']) self.assertIsNone(reader.doc('0')) @@ -193,7 +193,7 @@ def test_indexer_append2(self): indexer.add_doc_raw('{"id": "x", "contents": "Document x"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNone(reader.doc('0')) @@ -206,7 +206,7 @@ def test_indexer_type_raw(self): indexer.add_doc_raw('{"id": "doc1", "contents": "document 1 contents"}') indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) @@ -220,7 +220,7 @@ def test_indexer_type_raw_batch(self): indexer.add_batch_raw(batch) indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) @@ -232,7 +232,7 @@ def test_indexer_type_dict(self): indexer.add_doc_dict({'id': 'doc1', 'contents': 'document 1 contents'}) indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) @@ -246,7 +246,7 @@ def test_indexer_type_dict_batch(self): indexer.add_batch_dict(batch) indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) @@ -260,7 +260,7 @@ def test_indexer_type_json(self): indexer.add_doc_json(mapper.createObjectNode().put('id', 'doc1').put('contents', 'document 1 contents')) indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) @@ -275,7 +275,7 @@ def test_indexer_type_json_batch(self): indexer.add_batch_json(batch) indexer.close() - reader = IndexReader(self.tmp_dir) + reader = LuceneIndexReader(self.tmp_dir) stats = reader.stats() self.assertEqual(2, stats['documents']) self.assertIsNotNone(reader.doc('doc0')) diff --git a/tests/test_index_reader.py b/tests/test_index_reader.py index d736e7cc5..3592330fe 100644 --- a/tests/test_index_reader.py +++ b/tests/test_index_reader.py @@ -14,21 +14,22 @@ # limitations under the License. # +import heapq +import json import os import shutil import tarfile import unittest from random import randint from urllib.request import urlretrieve -import json -import heapq from sklearn.linear_model import LogisticRegression from sklearn.naive_bayes import MultinomialNB -from pyserini import analysis, search -from pyserini.index.lucene import IndexReader +from pyserini.analysis import get_lucene_analyzer +from pyserini.index.lucene import LuceneIndexReader from pyserini.pyclass import JString +from pyserini.search.lucene import LuceneSearcher, LuceneSimilarities from pyserini.vectorizer import BM25Vectorizer, TfidfVectorizer @@ -47,8 +48,8 @@ def setUp(self): tarball.close() self.index_path = os.path.join(self.index_dir, 'lucene9-index.cacm') - self.searcher = search.LuceneSearcher(self.index_path) - self.index_reader = IndexReader(self.index_path) + self.searcher = LuceneSearcher(self.index_path) + self.index_reader = LuceneIndexReader(self.index_path) self.temp_folders = [] @@ -72,7 +73,7 @@ def test_doc_vector_emoji_test(self): f'-generator DefaultLuceneDocumentGenerator ' + \ f'-threads 1 -input {self.emoji_corpus_path} -index {index_dir} -storeDocvectors' _ = os.system(cmd1) - temp_index_reader = IndexReader(index_dir) + temp_index_reader = LuceneIndexReader(index_dir) df, cf = temp_index_reader.get_term_counts('emoji') self.assertEqual(df, 1) @@ -157,7 +158,7 @@ def test_analyze(self): self.assertEqual(' '.join(self.index_reader.analyze('retrieval')), 'retriev') self.assertEqual(' '.join(self.index_reader.analyze('rapid retrieval, space economy')), 'rapid retriev space economi') - tokenizer = analysis.get_lucene_analyzer(stemming=False) + tokenizer = get_lucene_analyzer(stemming=False) self.assertEqual(' '.join(self.index_reader.analyze('retrieval', analyzer=tokenizer)), 'retrieval') self.assertEqual(' '.join(self.index_reader.analyze('rapid retrieval, space economy', analyzer=tokenizer)), 'rapid retrieval space economy') @@ -361,7 +362,7 @@ def test_query_doc_score_default(self): self.index_reader.compute_query_document_score(hits[i].docid, query), places=4) def test_query_doc_score_custom_similarity(self): - custom_bm25 = search.LuceneSimilarities.bm25(0.8, 0.2) + custom_bm25 = LuceneSimilarities.bm25(0.8, 0.2) queries = ['information retrieval', 'databases'] self.searcher.set_bm25(0.8, 0.2) @@ -375,7 +376,7 @@ def test_query_doc_score_custom_similarity(self): self.index_reader.compute_query_document_score( hits[i].docid, query, similarity=custom_bm25), places=4) - custom_qld = search.LuceneSimilarities.qld(500) + custom_qld = LuceneSimilarities.qld(500) self.searcher.set_qld(500) for query in queries: @@ -417,7 +418,7 @@ def compare_searcher(query): The query for search. """ # Search through documents BM25 dump - query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer()) + query_terms = self.index_reader.analyze(query, analyzer=get_lucene_analyzer()) heap = [] # heapq implements a min-heap, we can invert the values to have a max-heap for line in dump_file: @@ -469,7 +470,7 @@ def compare_searcher_quantized(query, tolerance=1): searching through documents in the dump is 2, then with a tolerance of 1 the ranking of the same document with Lucene searcher should be between 1-3. """ - query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer()) + query_terms = self.index_reader.analyze(query, analyzer=get_lucene_analyzer()) heap = [] for line in quantized_weights_file: doc = json.loads(line) diff --git a/tests/test_index_reader_lucene8.py b/tests/test_index_reader_lucene8.py index 353fbb3c9..7e5098ce3 100644 --- a/tests/test_index_reader_lucene8.py +++ b/tests/test_index_reader_lucene8.py @@ -23,8 +23,9 @@ from random import randint from urllib.request import urlretrieve -from pyserini import analysis, search -from pyserini.index.lucene import IndexReader +from pyserini.analysis import get_lucene_analyzer +from pyserini.index.lucene import LuceneIndexReader +from pyserini.search.lucene import LuceneSearcher, LuceneSimilarities class TestIndexUtilsForLucene8(unittest.TestCase): @@ -43,8 +44,8 @@ def setUp(self): tarball.close() self.index_path = os.path.join(self.index_dir, 'lucene-index.cacm') - self.searcher = search.LuceneSearcher(self.index_path) - self.index_reader = IndexReader(self.index_path) + self.searcher = LuceneSearcher(self.index_path) + self.index_reader = LuceneIndexReader(self.index_path) self.temp_folders = [] @@ -61,7 +62,7 @@ def test_query_doc_score_default(self): self.index_reader.compute_query_document_score(hits[i].docid, query), places=4) def test_query_doc_score_custom_similarity(self): - custom_bm25 = search.LuceneSimilarities.bm25(0.8, 0.2) + custom_bm25 = LuceneSimilarities.bm25(0.8, 0.2) queries = ['information retrieval', 'databases'] self.searcher.set_bm25(0.8, 0.2) @@ -75,7 +76,7 @@ def test_query_doc_score_custom_similarity(self): self.index_reader.compute_query_document_score( hits[i].docid, query, similarity=custom_bm25), places=4) - custom_qld = search.LuceneSimilarities.qld(500) + custom_qld = LuceneSimilarities.qld(500) self.searcher.set_qld(500) for query in queries: @@ -107,7 +108,7 @@ def compare_searcher(query): The query for search. """ # Search through documents BM25 dump - query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer()) + query_terms = self.index_reader.analyze(query, analyzer=get_lucene_analyzer()) heap = [] # heapq implements a min-heap, we can invert the values to have a max-heap for line in dump_file: @@ -159,7 +160,7 @@ def compare_searcher_quantized(query, tolerance=1): searching through documents in the dump is 2, then with a tolerance of 1 the ranking of the same document with Lucene searcher should be between 1-3. """ - query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer()) + query_terms = self.index_reader.analyze(query, analyzer=get_lucene_analyzer()) heap = [] for line in quantized_weights_file: doc = json.loads(line) diff --git a/tests/test_lucene_dense_search.py b/tests/test_lucene_dense_search.py index fcf55dddb..74d6675ae 100644 --- a/tests/test_lucene_dense_search.py +++ b/tests/test_lucene_dense_search.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import glob import os import unittest -from pyserini.util import get_cache_home from pyserini.search import get_topics from pyserini.search.lucene import LuceneHnswDenseSearcher, LuceneFlatDenseSearcher +from pyserini.util import get_cache_home class TestLuceneDenseSearch(unittest.TestCase): diff --git a/tests/test_nfcorpus_lucene.py b/tests/test_nfcorpus_lucene.py new file mode 100644 index 000000000..25b4935b2 --- /dev/null +++ b/tests/test_nfcorpus_lucene.py @@ -0,0 +1,68 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import subprocess +import tarfile +import unittest + +from random import randint +from urllib.request import urlretrieve + + +class TestNFCorpus(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.queries = 'tests/resources/nfcorpus-queries.tsv' + cls.qrels = 'tests/resources/nfcorpus-qrels.tsv' + + r = randint(0, 10000000) + + cls.sparse_index_url = 'https://github.com/castorini/anserini-data/raw/master/NFCorpus/lucene.nfcorpus.tar.gz' + cls.sparse_tarball_name = f'lucene.nfcorpus-{r}.tar.gz' + cls.sparse_index_dir = f'lucene.nfcorpus-{r}/' + + urlretrieve(cls.sparse_index_url, cls.sparse_tarball_name) + + tarball = tarfile.open(cls.sparse_tarball_name) + tarball.extractall(cls.sparse_index_dir) + tarball.close() + + def test_sparse_retrieval(self): + r = randint(0, 10000000) + run_file = f'run.{r}.txt' + cmd = f'python -m pyserini.search.lucene \ + --index {self.sparse_index_dir}/lucene.nfcorpus \ + --topics {self.queries} \ + --output {run_file} \ + --batch 32 --threads 4 \ + --hits 10 --bm25' + + os.system(cmd) + results = subprocess.check_output( + f'python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 {self.qrels} {run_file}', shell=True) + results = results.decode('utf-8').split('\n') + ndcg_line = results[-2] + ndcg_score = float(ndcg_line.split('\t')[-1]) + self.assertAlmostEqual(ndcg_score, 0.3405, places=5) + + os.remove(run_file) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.sparse_index_dir) + os.remove(cls.sparse_tarball_name) diff --git a/tests/test_prebuilt_index.py b/tests/test_prebuilt_index.py index de1bb5fa0..476493aa9 100644 --- a/tests/test_prebuilt_index.py +++ b/tests/test_prebuilt_index.py @@ -14,12 +14,13 @@ # limitations under the License. # -import requests import unittest -from pyserini.pyclass import autoclass +import requests + from pyserini.prebuilt_index_info import TF_INDEX_INFO, IMPACT_INDEX_INFO, \ LUCENE_HNSW_INDEX_INFO, LUCENE_FLAT_INDEX_INFO, FAISS_INDEX_INFO +from pyserini.pyclass import autoclass class TestPrebuiltIndexes(unittest.TestCase): diff --git a/tests/test_search.py b/tests/test_search.py index 278796899..1086d0f6d 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -22,8 +22,8 @@ from typing import List, Dict from urllib.request import urlretrieve -from pyserini.search.lucene import LuceneSearcher, JScoredDoc from pyserini.index.lucene import Document +from pyserini.search.lucene import LuceneSearcher, JScoredDoc class TestSearch(unittest.TestCase): diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index e719e4e37..9fd57bfbe 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -17,6 +17,7 @@ import unittest from transformers import BertTokenizer, T5Tokenizer, AutoTokenizer + from pyserini.analysis import Analyzer, get_lucene_analyzer diff --git a/tests/test_trectools.py b/tests/test_trectools.py index 4dbe5e9ea..8eaf869cb 100644 --- a/tests/test_trectools.py +++ b/tests/test_trectools.py @@ -16,8 +16,8 @@ import filecmp import os -import unittest import subprocess +import unittest from pyserini.trectools import TrecRun, Qrels, RescoreMethod