Skip to content

Commit

Permalink
Bug fix and add link (#28)
Browse files Browse the repository at this point in the history
* Update HF readme template

* Fix type for def retrieve

* Add tests for retrieve function

* Update retrieve to take tuples
  • Loading branch information
xhluca authored Jul 10, 2024
1 parent a30b1ed commit 7cf3d11
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 8 deletions.
54 changes: 48 additions & 6 deletions bm25s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def _faketqdm(iterable, *args, **kwargs):
return iterable


if os.environ.get("DISABLE_TQDM", False):
tqdm = _faketqdm
# if can't import tqdm, use a fake tqdm
Expand Down Expand Up @@ -65,6 +66,26 @@ def get_unique_tokens(
unique_tokens.update(doc_tokens)
return unique_tokens

def _is_tuple_of_list_of_tokens(obj):
if not isinstance(obj, tuple):
return False

if len(obj) == 0:
return False

first_elem = obj[0]
if not isinstance(first_elem, list):
return False

if len(first_elem) == 0:
return False

first_token = first_elem[0]
if not isinstance(first_token, str):
return False

return True


def _calculate_scores_with_arrays(
data, indptr, indices, num_docs, query_tokens_ids, dtype
Expand Down Expand Up @@ -416,11 +437,11 @@ def _get_top_k_results(

def retrieve(
self,
query_tokens: List[List[str]],
query_tokens: Union[List[List[str]], tokenization.Tokenized],
corpus: List[Any] = None,
k: int = 10,
sorted: bool = True,
return_as: bool = "tuple",
return_as: str = "tuple",
show_progress: bool = True,
leave_progress: bool = False,
n_threads: int = 0,
Expand All @@ -432,7 +453,7 @@ def retrieve(
Parameters
----------
query_tokens : List[List[str]]
query_tokens : List[List[str]] or bm25s.tokenization.Tokenized
List of list of tokens for each query. If a Tokenized object is provided,
it will be converted to a list of list of tokens.
Expand All @@ -452,7 +473,7 @@ def retrieve(
sorted : bool
If True, the function will sort the results by score before returning them.
return_as : bool
return_as : str
If return_as="tuple", a named tuple with two fields will be returned:
`documents` and `scores`, which can be accessed as `result.documents` and
`result.scores`, or by unpacking, e.g. `documents, scores = retrieve(...)`.
Expand All @@ -461,7 +482,7 @@ def retrieve(
show_progress : bool
If True, a progress bar will be shown. If False, no progress bar will be shown.
leave_progress : bool
If True, the progress bars will remain after the function completes.
Expand All @@ -485,6 +506,28 @@ def retrieve(

if n_threads == -1:
n_threads = os.cpu_count()


if isinstance(query_tokens, tuple) and not _is_tuple_of_list_of_tokens(query_tokens):
if len(query_tokens) != 2:
msg = (
"Expected a list of string or a tuple of two elements: the first element is the "
"list of unique token IDs, "
"and the second element is the list of token IDs for each document."
f"Found {len(query_tokens)} elements instead."
)
raise ValueError(msg)
else:
ids, vocab = query_tokens
if not isinstance(ids, Iterable):
raise ValueError(
"The first element of the tuple passed to retrieve must be an iterable."
)
if not isinstance(vocab, dict):
raise ValueError(
"The second element of the tuple passed to retrieve must be a dictionary."
)
query_tokens = tokenization.Tokenized(ids=ids, vocab=vocab)

if isinstance(query_tokens, tokenization.Tokenized):
query_tokens = tokenization.convert_tokenized_to_string_list(query_tokens)
Expand Down Expand Up @@ -750,7 +793,6 @@ def load(
"num_docs": params.pop("num_docs", None),
}


bm25_obj = cls(**params)
bm25_obj.scores = scores
bm25_obj.vocab_dict = vocab_dict
Expand Down
9 changes: 7 additions & 2 deletions bm25s/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@
This is a BM25S index created with the [`bm25s` library](https://github.com/xhluca/bm25s) (version `{version}`), an ultra-fast implementation of BM25. It can be used for lexical retrieval tasks.
💻[BM25S GitHub Repository](https://github.com/xhluca/bm25s)\\
🌐[BM25S Homepage](https://bm25s.github.io)
BM25S Related Links:
* 🏠[Homepage](https://bm25s.github.io)
* 💻[GitHub Repository](https://github.com/xhluca/bm25s)
* 🤗[Blog Post](https://huggingface.co/blog/xhluca/bm25s)
* 📝[Technical Report](https://arxiv.org/abs/2407.03618)
## Installation
Expand Down
99 changes: 99 additions & 0 deletions tests/quick/test_retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import shutil
from pathlib import Path
import unittest
import tempfile

import numpy as np
import bm25s
import Stemmer # optional: for stemming

from .. import BM25TestCase

class TestBM25SLoadingSaving(BM25TestCase):
@classmethod
def setUpClass(cls):

# Create your corpus here
corpus = [
"a cat is a feline and likes to purr",
"a dog is the human's best friend and loves to play",
"a bird is a beautiful animal that can fly",
"a fish is a creature that lives in water and swims",
]

# optional: create a stemmer
stemmer = Stemmer.Stemmer("english")

# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

# Create the BM25 model and index the corpus
retriever = bm25s.BM25(method='bm25+')
retriever.index(corpus_tokens)

# Save the retriever to temp dir
cls.retriever = retriever
cls.corpus = corpus
cls.corpus_tokens = corpus_tokens
cls.stemmer = stemmer

def test_retrieve(self):
ground_truth = np.array([[0, 2]])

# first, try with default mode
query = "a cat is a feline, it's sometimes beautiful but cannot fly"
query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True)

# retrieve the top 2 documents
results = self.retriever.retrieve(query_tokens_obj, k=2).documents

# assert that the retrieved indices are correct
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# now, try tokenizing with text tokens
query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False)
results = self.retriever.retrieve(query_tokens_texts, k=2).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# now, try to pass a tuple of tokens
ids, vocab = query_tokens_obj
query_tokens_tuple = (ids, vocab)
results = self.retriever.retrieve(query_tokens_tuple, k=2).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

# finally, try to pass a 2-tuple of tokens with text tokens to "try to trick the system"
queries_as_tuple = (query_tokens_texts[0], query_tokens_texts[0])
# only retrieve 1 document
ground_truth = np.array([[0], [0]])
results = self.retriever.retrieve(queries_as_tuple, k=1).documents
self.assertTrue(np.array_equal(ground_truth, results), f"Expected {ground_truth}, got {results}")

def test_failure_of_bad_tuple(self):
# try to pass a tuple of tokens with different lengths
query = "a cat is a feline, it's sometimes beautiful but cannot fly"
query_tokens_obj = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=True)
query_tokens_texts = bm25s.tokenize([query], stopwords="en", stemmer=self.stemmer, return_ids=False)
ids, vocab = query_tokens_obj
query_tokens_tuple = (vocab, ids)

with self.assertRaises(ValueError):
self.retriever.retrieve(query_tokens_tuple, k=2)

# now, test if there's vocab twice or ids twice
query_tokens_tuple = (ids, ids)
with self.assertRaises(ValueError):
self.retriever.retrieve(query_tokens_tuple, k=2)

# finally, test only passing vocab
query_tokens_tuple = (vocab, )
with self.assertRaises(ValueError):
self.retriever.retrieve(query_tokens_tuple, k=2)





@classmethod
def tearDownClass(cls):
pass

0 comments on commit 7cf3d11

Please sign in to comment.