Skip to content

Commit

Permalink
Add new examples using numba
Browse files Browse the repository at this point in the history
  • Loading branch information
xhluca committed Sep 16, 2024
1 parent 8cc651e commit 942ab0b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
51 changes: 51 additions & 0 deletions examples/index_and_retrieve_with_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
# Example: Use Numba to speed up the retrieval process
```bash
pip install "bm25s[full]" numba
```
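This script downloads a BEIR dataset and builds the index itself. To build an index and upload it to the Hugging Face Hub instead, please refer to the `examples/index_and_upload_to_hf.py` script.
Now, to run this script, execute:
```bash
python examples/index_and_retrieve_with_numba.py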
```
"""
import os
import Stemmer

import bm25s.hf
import bm25s

def main(dataset='scifact', dataset_dir='./datasets'):
    """Index a BEIR dataset with BM25S and query it with the numba backend.

    The corpus is fetched through ``bm25s.utils.beir``, every document is
    represented by its title followed by its text, and both the corpus and
    the queries go through the same stemming tokenizer before indexing and
    retrieval. The best hit of the first query is printed.

    Args:
        dataset: Name of the BEIR dataset passed to the bm25s BEIR helpers.
        dataset_dir: Directory passed as ``save_dir`` to the BEIR helpers.
    """
    queries = [
        "Is chemotherapy effective for treating cancer?",
        "Is Cardiac injury is common in critical cases of COVID-19?",
    ]

    bm25s.utils.beir.download_dataset(dataset=dataset, save_dir=dataset_dir)
    corpus: dict = bm25s.utils.beir.load_corpus(dataset=dataset, save_dir=dataset_dir)
    corpus_records = [
        {'id': k, 'title': v["title"], 'text': v["text"]} for k, v in corpus.items()
    ]
    corpus_lst = [r["title"] + " " + r["text"] for r in corpus_records]

    # Use a single tokenizer for the corpus and the queries so that both are
    # stemmed and split identically. Indexing the raw strings would make the
    # retriever treat every document as a sequence of characters.
    stemmer = Stemmer.Stemmer("english")
    tokenizer = bm25s.tokenization.Tokenizer(stemmer=stemmer)
    corpus_tokenized = tokenizer.tokenize(corpus_lst, return_as="tuple")

    retriever = bm25s.BM25(corpus=corpus_records, backend='numba')
    retriever.index(corpus_tokenized)
    # corpus=corpus_records is optional, only used when you are calling retrieve and want to return the documents

    # Tokenize the queries; returning the (ids, vocab) tuple lets the retriever
    # match the query tokens against its own vocabulary.
    queries_tokenized = tokenizer.tokenize(queries, return_as="tuple")

    # Retrieve the top-k results
    results = retriever.retrieve(queries_tokenized, k=3)

    # Show the first result of the first query
    result = results.documents[0]
    print(f"First score (# 1 result): {results.scores[0, 0]:.4f}")
    print(f"First result id (# 1 result): {result[0]['id']}")
    print(f"First result title (# 1 result): {result[0]['title']}")


if __name__ == "__main__":
    main()
File renamed without changes.
45 changes: 45 additions & 0 deletions examples/retrieve_with_numba_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
# Example: Use Numba to speed up the retrieval process
```bash
pip install "bm25s[full]" numba
```
To build an index, please refer to the `examples/index_and_upload_to_hf.py` script.
Now, to run this script, execute:
```bash
python examples/retrieve_with_numba_hf.py
```
"""
import os
import Stemmer

import bm25s.hf

def main(repo_name="xhluca/bm25s-fiqa-index"):
    """Load a BM25S index from the Hugging Face Hub and query it with numba.

    The retriever is loaded without its corpus, switched to the numba
    backend, and asked for the top 3 hits of two example queries. The score
    and the first returned entry for the first query are printed.

    Args:
        repo_name: Hub repository passed to ``BM25HF.load_from_hub``.
    """
    example_queries = [
        "Is chemotherapy effective for treating cancer?",
        "Is Cardiac injury is common in critical cases of COVID-19?",
    ]

    hub_retriever = bm25s.hf.BM25HF.load_from_hub(
        repo_name,
        load_corpus=False,
        mmap=False,
    )

    # The backend can also be chosen when the retriever is initialized.
    hub_retriever.backend = "numba"

    # Stem and tokenize the queries.
    english_stemmer = Stemmer.Stemmer("english")
    query_tokenizer = bm25s.tokenization.Tokenizer(stemmer=english_stemmer)
    tokenized_queries = query_tokenizer.tokenize(example_queries)

    # Fetch the three best matches for each query.
    top_hits = hub_retriever.retrieve(tokenized_queries, k=3)

    # Report the best match of the first query.
    first_query_hits = top_hits.documents[0]
    print(f"First score (# 1 result): {top_hits.scores[0, 0]:.4f}")
    print(f"First result (# 1 result): {first_query_hits[0]}")


if __name__ == "__main__":
    main()

0 comments on commit 942ab0b

Please sign in to comment.