Improve behavior of tokenize/retrieve functions when faced with unknown tokens
xhluca committed Oct 31, 2024
1 parent 266216a commit aa31a23
Showing 3 changed files with 177 additions and 54 deletions.
152 changes: 99 additions & 53 deletions bm25s/__init__.py
@@ -22,6 +22,7 @@
except ImportError:
_retrieve_numba_functional = None


def _faketqdm(iterable, *args, **kwargs):
return iterable

@@ -48,8 +49,8 @@ def _faketqdm(iterable, *args, **kwargs):
_build_nonoccurrence_array,
)

-debug_logger = logging.getLogger("bm25s")
-debug_logger.setLevel(logging.DEBUG)
+logger = logging.getLogger("bm25s")
+logger.setLevel(logging.DEBUG)


class Results(NamedTuple):
@@ -106,24 +107,25 @@ def is_list_of_list_of_type(obj, type_=int):

return True


def _is_tuple_of_list_of_tokens(obj):
if not isinstance(obj, tuple):
return False

if len(obj) == 0:
return False

first_elem = obj[0]
if not isinstance(first_elem, list):
return False

if len(first_elem) == 0:
return False
-    first_token = first_elem[0]
+
+    first_token = first_elem[0]
if not isinstance(first_token, str):
return False

return True


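The helper's job is to tell the "tuple of token lists" query form apart from the `(ids, vocab)` tuple that `retrieve` also accepts. A quick sanity check of its behavior (a sketch, assuming the private helper is importable from the package root where this diff defines it):

```python
from bm25s import _is_tuple_of_list_of_tokens

# A tuple of per-query token lists is recognized...
print(_is_tuple_of_list_of_tokens((["hello", "world"], ["foo"])))  # True

# ...while the (ids, vocab) form, empty tuples, and plain lists are not.
print(_is_tuple_of_list_of_tokens(([[0, 1], [2]], {"hello": 0})))  # False: first token is a list, not a str
print(_is_tuple_of_list_of_tokens(()))                             # False: empty tuple
print(_is_tuple_of_list_of_tokens([["hello"]]))                    # False: a list, not a tuple
```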
@@ -171,7 +173,7 @@ def __init__(
The corpus of documents. This is optional and is used for saving the corpus
to the snapshot. We expect the corpus to be a list of dictionaries, where each
dictionary represents a document.
backend : str
The backend used during retrieval. By default, it uses the numpy backend, which
only requires numpy and scipy as dependencies. You can also select `backend="numba"`
@@ -232,7 +234,7 @@ def _compute_relevance_from_scores(
by using the BM25 scores that have been precomputed in the BM25 eager index.
It is used by the `get_scores_from_ids` method, which makes use of the precomputed
scores assigned as attributes of the BM25 object.
Parameters
----------
data (np.ndarray)
@@ -252,7 +254,7 @@
-------
np.ndarray
Array of BM25 relevance scores for a given query.
Note
----
This function was optimized by the baguetter library. The original implementation can be found at:
@@ -429,18 +431,16 @@ def index(
inferred_corpus_obj = self._infer_corpus_object(corpus)

if inferred_corpus_obj == "tokens":
debug_logger.log(msg="Building index from tokens", level=logging.DEBUG)
logger.debug(msg="Building index from tokens")
scores, vocab_dict = self.build_index_from_tokens(
corpus, leave_progress=leave_progress, show_progress=show_progress
)
else:
if inferred_corpus_obj == "tuple":
debug_logger.log(msg="Building index from IDs", level=logging.DEBUG)
logger.debug(msg="Building index from IDs")
corpus_token_ids, vocab_dict = corpus
elif inferred_corpus_obj == "object":
-                debug_logger.log(
-                    msg="Building index from IDs objects", level=logging.DEBUG
-                )
+                logger.debug(msg="Building index from IDs objects")
corpus_token_ids = corpus.ids
vocab_dict = corpus.vocab
else:
@@ -467,8 +467,10 @@ def get_tokens_ids(self, query_tokens: List[str]) -> List[int]:
return [
self.vocab_dict[token] for token in query_tokens if token in self.vocab_dict
]

-    def get_scores_from_ids(self, query_tokens_ids: List[int], weight_mask=None) -> np.ndarray:
+    def get_scores_from_ids(
+        self, query_tokens_ids: List[int], weight_mask=None
+    ) -> np.ndarray:
data = self.scores["data"]
indices = self.scores["indices"]
indptr = self.scores["indptr"]
@@ -478,6 +480,13 @@ def get_scores_from_ids(self, query_tokens_ids: List[int], weight_mask=None) ->
int_dtype = np.dtype(self.int_dtype)
query_tokens_ids = np.asarray(query_tokens_ids, dtype=int_dtype)

+        max_token_id = query_tokens_ids.max()
+        if max_token_id >= len(indptr) - 1:
+            raise ValueError(
+                f"The maximum token ID in the query ({max_token_id}) is higher than the number of tokens in the index."
+                " This likely means that the query contains tokens that are not in the index."
+            )

scores = self._compute_relevance_from_scores(
data=data,
indptr=indptr,
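What the new guard does in practice, as a minimal sketch using the quick-start API from the project README (`bm25s.tokenize` plus `BM25.index`); the ID `10_000` is just an arbitrary out-of-vocabulary value:

```python
import bm25s

corpus = ["a cat is a feline", "a dog is a canine"]
retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(corpus))

# A token ID past the end of the indexed vocabulary now fails fast with a
# descriptive ValueError instead of silently misreading the CSC arrays.
try:
    retriever.get_scores_from_ids([10_000])
except ValueError as err:
    print(err)
```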
@@ -499,18 +508,22 @@ def get_scores_from_ids(self, query_tokens_ids: List[int], weight_mask=None) ->

return scores

-    def get_scores(self, query_tokens_single: List[str], weight_mask=None) -> np.ndarray:
+    def get_scores(
+        self, query_tokens_single: List[str], weight_mask=None
+    ) -> np.ndarray:
if not isinstance(query_tokens_single, list):
raise ValueError("The query_tokens must be a list of tokens.")

if isinstance(query_tokens_single[0], str):
query_tokens_ids = self.get_tokens_ids(query_tokens_single)
elif isinstance(query_tokens_single[0], int):
# already are token IDs, no need to convert
query_tokens_ids = query_tokens_single
else:
raise ValueError("The query_tokens must be a list of tokens or a list of token IDs.")

raise ValueError(
"The query_tokens must be a list of tokens or a list of token IDs."
)

return self.get_scores_from_ids(query_tokens_ids, weight_mask=weight_mask)

def _get_top_k_results(
@@ -525,11 +538,20 @@ def _get_top_k_results(
This function is used to retrieve the top-k results for a single query.
Since it's a hidden function, the user should not call it directly and
may change in the future. Please use the `retrieve` function instead.
"""
scores_q = self.get_scores(query_tokens_single, weight_mask=weight_mask)
if backend.startswith('numba'):
"""
if len(query_tokens_single) == 0:
logger.info(
msg="The query is empty. This will result in a zero score for all documents."
)
scores_q = np.zeros(self.scores["num_docs"], dtype=self.dtype)
else:
scores_q = self.get_scores(query_tokens_single, weight_mask=weight_mask)

if backend.startswith("numba"):
if selection_jit is None:
raise ImportError("Numba is not installed. Please install numba to use the numba backend.")
raise ImportError(
"Numba is not installed. Please install numba to use the numba backend."
)
topk_scores, topk_indices = selection_jit.topk(
scores_q, k=k, sorted=sorted, backend=backend
)
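The new empty-query branch can be exercised directly through the private helper (purely for illustration, reusing the `retriever` from the sketch above; the keyword arguments mirror the `partial(...)` call further down in this diff):

```python
# An empty token list now logs an info message and scores every document
# as zero instead of erroring out deeper in the scoring code.
scores, indices = retriever._get_top_k_results(
    [], k=2, sorted=True, backend="numpy", weight_mask=None
)
print(scores)   # all zeros
print(indices)  # arbitrary order, since every document ties at zero
```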
@@ -540,7 +562,6 @@ def _get_top_k_results(

return topk_scores, topk_indices


def retrieve(
self,
query_tokens: Union[List[List[str]], tokenization.Tokenized],
@@ -603,10 +624,24 @@ def retrieve(
backend_selection : str
The backend to use for the top-k retrieval. Choose from "auto", "numpy", "jax".
If "auto", it will use JAX if it is available, otherwise it will use numpy.
+        weight_mask : np.ndarray
+            A weight mask to filter the documents. If provided, the scores for the masked
+            documents will be set to 0 to avoid returning them in the results.
+
+        Returns
+        -------
+        Results or np.ndarray
+            If `return_as="tuple"`, a named tuple with two fields will be returned: `documents` and `scores`.
+            If `return_as="documents"`, only the retrieved documents (or indices if `corpus` is not provided) will be returned.
+
+        Raises
+        ------
+        ValueError
+            If the `query_tokens` is not a list of list of tokens (str) or a tuple of two lists: the first list is the list of unique token IDs, and the second list is the list of token IDs for each document.
+        ImportError
+            If the numba backend is selected but numba is not installed.
"""
allowed_return_as = ["tuple", "documents"]

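Putting the documented parameters together, a small end-to-end sketch in the style of the project README (the corpus and query are illustrative only):

```python
import bm25s

corpus = [
    "a cat is a feline and likes to purr",
    "a dog is the human's best friend and loves to play",
]
retriever = bm25s.BM25(corpus=corpus)
retriever.index(bm25s.tokenize(corpus))

# With the default return_as="tuple", retrieve returns the Results
# namedtuple defined earlier in this file: (documents, scores).
results, scores = retriever.retrieve(bm25s.tokenize("does the cat purr"), k=2)
print(results[0, 0], scores[0, 0])  # best-matching document and its score
```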
@@ -617,22 +652,26 @@

if n_threads == -1:
n_threads = os.cpu_count()

# if it's a list of list of tokens ids (int), we remove any integer not in the vocab_dict
if is_list_of_list_of_type(query_tokens, type_=int):
query_tokens_filtered = []
for query in query_tokens:
-                query_filtered = [token_id for token_id in query if token_id in self.vocab_dict]
+                query_filtered = [
+                    token_id for token_id in query if token_id in self.vocab_dict
+                ]
if len(query_filtered) == 0:
if "" not in self.vocab_dict:
self.vocab_dict[""] = max(self.vocab_dict.values()) + 1
query_filtered = [self.vocab_dict[""]]

query_tokens_filtered.append(query_filtered)

query_tokens = query_tokens_filtered

-        if isinstance(query_tokens, tuple) and not _is_tuple_of_list_of_tokens(query_tokens):
+        if isinstance(query_tokens, tuple) and not _is_tuple_of_list_of_tokens(
+            query_tokens
+        ):
if len(query_tokens) != 2:
msg = (
"Expected a list of string or a tuple of two elements: the first element is the "
@@ -655,28 +694,32 @@

if isinstance(query_tokens, tokenization.Tokenized):
query_tokens = tokenization.convert_tokenized_to_string_list(query_tokens)

corpus = corpus if corpus is not None else self.corpus

if weight_mask is not None:
if not isinstance(weight_mask, np.ndarray):
raise ValueError("weight_mask must be a numpy array.")

# check if weight_mask is a 1D array, if not raise an error
if weight_mask.ndim != 1:
raise ValueError("weight_mask must be a 1D array.")

# check if the length of the weight_mask is the same as the length of the corpus
            if len(weight_mask) != self.scores["num_docs"]:
raise ValueError(
"The length of the weight_mask must be the same as the length of the corpus."
)

if self.backend == "numba":
if _retrieve_numba_functional is None:
raise ImportError("Numba is not installed. Please install numba wiith `pip install numba` to use the numba backend.")

backend_selection = "numba" if backend_selection == "auto" else backend_selection
raise ImportError(
"Numba is not installed. Please install numba wiith `pip install numba` to use the numba backend."
)

backend_selection = (
"numba" if backend_selection == "auto" else backend_selection
)
# if is list of list of int
if is_list_of_list_of_type(query_tokens, type_=int):
query_tokens_ids = query_tokens
@@ -686,7 +729,7 @@
raise ValueError(
"The query_tokens must be a list of list of tokens (str for stemmed words, int for token ids matching corpus) or a tuple of two lists: the first list is the list of unique token IDs, and the second list is the list of token IDs for each document."
)

res = _retrieve_numba_functional(
query_tokens_ids=query_tokens_ids,
scores=self.scores,
Expand All @@ -697,27 +740,30 @@ def retrieve(
show_progress=show_progress,
leave_progress=leave_progress,
n_threads=n_threads,
                chunksize=None,  # chunksize is ignored in the numba backend
                backend_selection=backend_selection,  # backend_selection is ignored in the numba backend
dtype=self.dtype,
int_dtype=self.int_dtype,
-                nonoccurrence_array=self.nonoccurrence_array
+                nonoccurrence_array=self.nonoccurrence_array,
)

if return_as == "tuple":
return Results(documents=res[0], scores=res[1])
else:
return res


tqdm_kwargs = {
"total": len(query_tokens),
"desc": "BM25S Retrieve",
"leave": leave_progress,
"disable": not show_progress,
}
topk_fn = partial(
-            self._get_top_k_results, k=k, sorted=sorted, backend=backend_selection, weight_mask=weight_mask
+            self._get_top_k_results,
+            k=k,
+            sorted=sorted,
+            backend=backend_selection,
+            weight_mask=weight_mask,
)

if n_threads == 0:
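A usage sketch of the contract these checks enforce (one weight per indexed document; a zero entry forces a zero score, so that document cannot rank):

```python
import numpy as np
import bm25s

corpus = ["a cat is a feline", "a dog is a canine", "a bird is an avian"]
retriever = bm25s.BM25(corpus=corpus)
retriever.index(bm25s.tokenize(corpus))

# Mask out the second document: it can then only appear with a zero score.
mask = np.array([1.0, 0.0, 1.0])
results, scores = retriever.retrieve(
    bm25s.tokenize("the dog is a canine"), k=2, weight_mask=mask
)
print(results[0], scores[0])
```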
@@ -882,7 +928,7 @@ def save(
utils.corpus.save_mmindex(mmidx, path=save_dir / corpus_name)

def load_scores(
        self,
save_dir,
data_name="data.csc.index.npy",
indices_name="indices.csc.index.npy",
@@ -936,7 +982,6 @@ def load_scores(

self.scores = scores


@classmethod
def load(
cls,
@@ -995,7 +1040,7 @@ def load(
allow_pickle : bool
If True, the arrays will be loaded using pickle. If False, the arrays will be loaded
in a more efficient format, but they will not be readable by older versions of numpy.
load_vocab : bool
If True, the vocab dictionary will be loaded from the `vocab_name` file. If False, the vocab dictionary
will not be loaded, and the `vocab_dict` attribute of the BM25 object will be set to None.
@@ -1018,7 +1063,7 @@ def load(
vocab_dict: dict = json_functions.loads(f.read())
else:
vocab_dict = None

original_version = params.pop("version", None)
num_docs = params.pop("num_docs", None)

@@ -1091,5 +1136,6 @@ def activate_numba_scorer(self):

from .scoring import _compute_relevance_from_scores_jit_ready

-        self._compute_relevance_from_scores = njit(_compute_relevance_from_scores_jit_ready)
+        self._compute_relevance_from_scores = njit(
+            _compute_relevance_from_scores_jit_ready
+        )
9 changes: 8 additions & 1 deletion bm25s/tokenization.py
@@ -273,7 +273,14 @@ def streaming_tokenize(

if len(doc_ids) == 0 and allow_empty is True:
if update_vocab is True and "" not in self.word_to_id:
self.word_to_id[""] = max(self.word_to_id.values(), default=0) + 1
idx = max(self.word_to_id.values(), default=-1) + 1
self.word_to_id[""] = idx

if using_stemmer:
if "" not in self.word_to_stem:
self.word_to_stem[""] = ""
if "" not in self.stem_to_sid:
self.stem_to_sid[""] = idx

# get the ID for the empty string
if "" in self.word_to_id:
