Improve tokenizer (#51)
* Add a tokenizer class (WIP)

* fix word_to_wid logic and add example for using tokenizer class

* WIP changes to make token ids valid inputs of retrieve, still need to be thoroughly tested

* add todo to example

* Major refactoring of tokenizer class

* Minor QOL improvements

* Refactor streaming_tokenize to be faster by reducing unnecessary set checks

* Remove _word_to_wid to simplify vocab design. Now, word_to_id is updated when stemmer is not used

* Update beir.py utils to use new URL

* Remove unused function, lint code, add test cases

* Add example of using the tokenizer

* Rename example

* Add details about the new tokenizer class in readme

* Update class to test tokenize in int, ids, strings, tuple
xhluca authored Sep 8, 2024
1 parent 8b36bdd commit 0a49c62
Showing 6 changed files with 570 additions and 58 deletions.
42 changes: 42 additions & 0 deletions README.md
@@ -180,6 +180,48 @@ retriever = bm25s.BM25.load("bm25s_very_big_index", mmap=True)

For an example of how to use `retrieve` with the `mmap=True` mode, check out [`examples/retrieve_nq.py`](examples/retrieve_nq.py).
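As a quick illustration, here is a minimal sketch (not part of the original README) of loading a memory-mapped index and querying it; it assumes an index was previously saved under `bm25s_very_big_index`, as in the snippet above:

```python
import bm25s

# Memory-map the saved index instead of loading it fully into RAM
retriever = bm25s.BM25.load("bm25s_very_big_index", mmap=True)

# Tokenize the query with the simple helper and retrieve the top 2 documents
results, scores = retriever.retrieve(bm25s.tokenize("a cat is a feline"), k=2)
```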


## Tokenization

In addition to the simple `bm25s.tokenize` function, you can use the `Tokenizer` class to customize the tokenization process. This is useful when you want to customize how text is split, stemmed, or filtered, or when you want to tokenize queries and documents differently:

```python
from bm25s.tokenization import Tokenizer

corpus = [
    "a cat is a feline and likes to purr",
    "a dog is the human's best friend and loves to play",
    "a bird is a beautiful animal that can fly",
    "a fish is a creature that lives in water and swims",
]

# Pick your favorite stemmer, and pass it to the Tokenizer below
stemmer = None
stopwords = []
splitter = lambda x: x.split() # function or regex pattern
# Create a tokenizer
tokenizer = Tokenizer(
    stemmer=stemmer, stopwords=stopwords, splitter=splitter
)

corpus_tokens = tokenizer.tokenize(corpus)

# let's see what the tokens look like
print("tokens:", corpus_tokens)
print("vocab:", tokenizer.get_vocab_dict())

# note: the vocab dict will either be a dict of `word -> id` if you don't have a stemmer, or a dict of `stemmed word -> stem id` if you do.
```

You can find advanced examples in [examples/tokenizer_class.py](examples/tokenizer_class.py), including how to:
* Pass a stemmer, stopwords, and splitter function/regex pattern
* Control whether vocabulary is updated by `tokenizer.tokenize` calls or not (by default, it will only be updated during the first call)
* Reset the tokenizer to its initial state with `tokenizer.reset_vocab()`
* Use the tokenizer in generator mode to save memory by `yield`ing one document at a time.
* Pass different outputs of the tokenizer to the `BM25.retrieve` function (a short sketch of this is shown below).
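
To illustrate the last point, here is a minimal sketch (not taken from `examples/tokenizer_class.py`) that indexes the corpus with the tokenizer's output and passes token IDs directly to `retrieve`. It reuses `tokenizer` and `corpus` from the snippet above, and the `return_as` keyword values are assumptions based on the commit notes; check the example file for the exact API:

```python
import bm25s

# Re-tokenize the corpus as a (ids, vocab) tuple so the retriever shares the
# tokenizer's token IDs (assumes `return_as="tuple"` is supported).
corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")

retriever = bm25s.BM25()
retriever.index(corpus_tokens)

# Tokenize the query as lists of token IDs; by default the vocabulary is only
# updated on the first call, so no new words are added here.
query_ids = tokenizer.tokenize(["a cat likes to purr and play"], return_as="ids")

# `retrieve` now also accepts a list of lists of token IDs (ints)
results, scores = retriever.retrieve(query_ids, k=2)
```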



## Variants

You can use the following variants of BM25 in `bm25s` (see [Kamphuis et al. 2020](https://link.springer.com/chapter/10.1007/978-3-030-45442-5_4) for more details):
45 changes: 42 additions & 3 deletions bm25s/__init__.py
@@ -73,6 +73,27 @@ def get_unique_tokens(
        unique_tokens.update(doc_tokens)
    return unique_tokens


def is_list_of_list_of_type(obj, type_=int):
    # Cheap check: `obj` must be a non-empty list whose first element is a
    # non-empty list, and whose first token is an instance of `type_`.
    if not isinstance(obj, list):
        return False

    if len(obj) == 0:
        return False

    first_elem = obj[0]
    if not isinstance(first_elem, list):
        return False

    if len(first_elem) == 0:
        return False

    first_token = first_elem[0]
    if not isinstance(first_token, type_):
        return False

    return True

def _is_tuple_of_list_of_tokens(obj):
    if not isinstance(obj, tuple):
        return False
@@ -467,7 +488,17 @@ def get_scores_from_ids(self, query_tokens_ids: List[int], weight_mask=None) ->
        return scores

    def get_scores(self, query_tokens_single: List[str], weight_mask=None) -> np.ndarray:
        if not isinstance(query_tokens_single, list):
            raise ValueError("The query_tokens must be a list of tokens.")

        if isinstance(query_tokens_single[0], str):
            query_tokens_ids = self.get_tokens_ids(query_tokens_single)
        elif isinstance(query_tokens_single[0], int):
            # already are token IDs, no need to convert
            query_tokens_ids = query_tokens_single
        else:
            raise ValueError("The query_tokens must be a list of tokens or a list of token IDs.")

        return self.get_scores_from_ids(query_tokens_ids, weight_mask=weight_mask)

    def _get_top_k_results(
@@ -575,7 +606,6 @@ def retrieve(
        if n_threads == -1:
            n_threads = os.cpu_count()

        if isinstance(query_tokens, tuple) and not _is_tuple_of_list_of_tokens(query_tokens):
            if len(query_tokens) != 2:
                msg = (
@@ -621,7 +651,16 @@
                raise ImportError("Numba is not installed. Please install numba with `pip install numba` to use the numba backend.")

            backend_selection = "numba" if backend_selection == "auto" else backend_selection
            # if query_tokens is a list of list of int, they are already token IDs
            if is_list_of_list_of_type(query_tokens, type_=int):
                query_tokens_ids = query_tokens
            elif is_list_of_list_of_type(query_tokens, type_=str):
                query_tokens_ids = [self.get_tokens_ids(q) for q in query_tokens]
            else:
                raise ValueError(
                    "The query_tokens must be a list of list of tokens (str for stemmed words, int for token ids matching corpus) or a tuple of two lists: the first list is the list of unique token IDs, and the second list is the list of token IDs for each document."
                )

            res = _retrieve_numba_functional(
                query_tokens_ids=query_tokens_ids,
                scores=self.scores,
