Added support for saving and loading non-ASCII chars in corpus and vocab (#86)

* Added support for saving and loading non-ASCII chars in corpus and vocab

* Added kwargs in json_functions for the ensure_ascii parameter

* Added ensure_ascii=False in tokenization

* Added the test from #79

* Reworked the assertion to match the GitHub Actions results
IssacXid authored Nov 26, 2024
1 parent f161fba commit db53725
Showing 4 changed files with 70 additions and 13 deletions.
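For context: `ensure_ascii` controls whether `json.dumps` escapes non-ASCII characters. A minimal sketch with the standard library only (not part of the commit):

```python
import json

vocab = {"שלום": 0, "café": 1}

# Default (ensure_ascii=True): non-ASCII characters are escaped to \uXXXX,
# so the output is plain ASCII but unreadable for non-Latin scripts.
print(json.dumps(vocab))                      # {"\u05e9\u05dc\u05d5\u05dd": 0, "caf\u00e9": 1}

# ensure_ascii=False writes the characters verbatim; the file they are written
# to must then be opened with an explicit utf-8 encoding, as the diffs below do.
print(json.dumps(vocab, ensure_ascii=False))  # {"שלום": 0, "café": 1}
```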
8 changes: 4 additions & 4 deletions bm25s/__init__.py
@@ -876,9 +876,9 @@ def save(
         # Save the vocab dictionary
         vocab_path = save_dir / vocab_name

-        with open(vocab_path, "w") as f:
-            f.write(json_functions.dumps(self.vocab_dict))
-
+        with open(vocab_path, "wt", encoding='utf-8') as f:
+            f.write(json_functions.dumps(self.vocab_dict, ensure_ascii=False))
+
         # Save the parameters
         params_path = save_dir / params_name
         params = dict(
@@ -1060,7 +1060,7 @@ def load(
         # Load the vocab dictionary
         if load_vocab:
             vocab_path = save_dir / vocab_name
-            with open(vocab_path, "r") as f:
+            with open(vocab_path, "r", encoding='utf-8') as f:
                 vocab_dict: dict = json_functions.loads(f.read())
         else:
             vocab_dict = None
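Taken together, the two hunks above make the save/load round trip UTF-8 safe. A sketch of the user-facing effect (hypothetical usage; `bm25s.tokenize` and the classmethod form of `BM25.load` are assumptions, not shown in this diff):

```python
import tempfile
import bm25s

corpus = ["a cat is a feline", "שלום חברים", "El café está muy caliente"]

retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(corpus))  # assumption: default tokenizer handles these strings

# The vocab JSON is now written with ensure_ascii=False and read back as UTF-8
tmpdir = tempfile.mkdtemp()
retriever.save(tmpdir, corpus=corpus)
reloaded = bm25s.BM25.load(tmpdir, load_corpus=True)
```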
6 changes: 3 additions & 3 deletions bm25s/tokenization.py
@@ -121,13 +121,13 @@ def save_vocab(self, save_dir: str, vocab_name: str = "vocab.tokenizer.json"):
         path = save_dir / vocab_name

         save_dir.mkdir(parents=True, exist_ok=True)
-        with open(path, "w") as f:
+        with open(path, "w", encoding='utf-8') as f:
             d = {
                 "word_to_stem": self.word_to_stem,
                 "stem_to_sid": self.stem_to_sid,
                 "word_to_id": self.word_to_id,
             }
-            f.write(json_functions.dumps(d))
+            f.write(json_functions.dumps(d, ensure_ascii=False))

     def load_vocab(self, save_dir: str, vocab_name: str = "vocab.tokenizer.json"):
         """
@@ -150,7 +150,7 @@ def load_vocab(self, save_dir: str, vocab_name: str = "vocab.tokenizer.json"):
         """
         path = Path(save_dir) / vocab_name

-        with open(path, "r") as f:
+        with open(path, "r", encoding='utf-8') as f:
             d = json_functions.loads(f.read())
         self.word_to_stem = d["word_to_stem"]
         self.stem_to_sid = d["stem_to_sid"]
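A sketch of the vocab round trip these two methods now support, assuming the enclosing class is `bm25s.tokenization.Tokenizer` (the class definition is outside this diff):

```python
import tempfile
from bm25s.tokenization import Tokenizer  # assumption: the class patched above

tokenizer = Tokenizer()
tokenizer.tokenize(["שלום חברים", "Türkçe öğreniyorum."])  # populates the vocab maps

tmpdir = tempfile.mkdtemp()
tokenizer.save_vocab(tmpdir)   # writes vocab.tokenizer.json verbatim as UTF-8

fresh = Tokenizer()
fresh.load_vocab(tmpdir)       # non-ASCII tokens come back intact
```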
19 changes: 13 additions & 6 deletions bm25s/utils/json_functions.py
@@ -7,16 +7,23 @@
     ORJSON_AVAILABLE = False


-def dumps_with_builtin(d: dict) -> str:
-    return json.dumps(d)
+def dumps_with_builtin(d: dict, **kwargs) -> str:
+    return json.dumps(d, **kwargs)

-def dumps_with_orjson(d: dict) -> str:
-    return orjson.dumps(d).decode('utf-8')
+def dumps_with_orjson(d: dict, **kwargs) -> str:
+    if kwargs.get("ensure_ascii", True):
+        # Simulate `ensure_ascii=True` by escaping non-ASCII characters
+        return orjson.dumps(d).decode("utf-8").encode("ascii", "backslashreplace").decode("utf-8")
+    # Ignore other kwargs not supported by orjson
+    return orjson.dumps(d).decode("utf-8")

 if ORJSON_AVAILABLE:
-    dumps = dumps_with_orjson
+    def dumps(d: dict, **kwargs) -> str:
+        return dumps_with_orjson(d, **kwargs)
     loads = orjson.loads
 else:
-    dumps = dumps_with_builtin
+    def dumps(d: dict, **kwargs) -> str:
+        return dumps_with_builtin(d, **kwargs)
     loads = json.loads
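With this shim, callers get one `dumps` signature regardless of backend; orjson has no `ensure_ascii` option, so `ensure_ascii=True` is emulated with `backslashreplace` (note the escape form differs: `\uXXXX` from the builtin `json`, `\xXX` for Latin-1 characters on the orjson path). A small sketch:

```python
from bm25s.utils import json_functions

d = {"café": 1}

# Default: non-ASCII is backslash-escaped on either backend
print(json_functions.dumps(d))

# ensure_ascii=False: written verbatim, as save()/save_vocab() now request
# (whitespace may differ slightly between orjson and the builtin json)
print(json_functions.dumps(d, ensure_ascii=False))
```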


50 changes: 50 additions & 0 deletions tests/core/test_save_load.py
@@ -25,6 +25,11 @@ def setUpClass(cls):
             "a dog is the human's best friend and loves to play",
             "a bird is a beautiful animal that can fly",
             "a fish is a creature that lives in water and swims",
+            "שלום חברים, איך אתם היום?",
+            "El café está muy caliente",
+            "今天的天气真好!",
+            "Как дела?",
+            "Türkçe öğreniyorum."
         ]

         # optional: create a stemmer
@@ -119,6 +124,51 @@ def test_d_load_no_orjson(self):
         self.test_b_load()


     @classmethod
     def tearDownClass(cls):
         # remove the temp dir with rmtree
         shutil.rmtree(cls.tmpdirname)


+class TestBM25SNonASCIILoadingSaving(unittest.TestCase):
+    orjson_should_not_be_installed = False
+    orjson_should_be_installed = True
+
+    @classmethod
+    def setUpClass(cls):
+        import bm25s
+
+        cls.text = ["Thanks for your great work!"]  # ASCII-only text always worked
+        cls.text = ["שלום חברים"]  # non-ASCII text used to crash on save/load (#79)
+
+        # create a vocabulary
+        tokens = [t.split() for t in cls.text]
+        unique_tokens = set([item for sublist in tokens for item in sublist])
+        vocab_token2id = {token: i for i, token in enumerate(unique_tokens)}
+
+        # create a tokenized corpus
+        token_ids = [
+            [vocab_token2id[token] for token in text_tokens if token in vocab_token2id]
+            for text_tokens in tokens
+        ]
+        corpus_tokens = bm25s.tokenization.Tokenized(ids=token_ids, vocab=vocab_token2id)
+
+        # create a retriever
+        cls.retriever = bm25s.BM25()
+        cls.retriever.index(corpus_tokens)
+        cls.tmpdirname = tempfile.mkdtemp()
+
+    def setUp(self):
+        # verify that orjson is properly installed
+        try:
+            import orjson
+        except ImportError:
+            self.fail("orjson should be installed to run this test.")
+
+    def test_a_save_and_load(self):
+        # before this fix, both calls raised:
+        # UnicodeEncodeError: 'charmap' codec can't encode characters in position 2-6: character maps to <undefined>
+        self.retriever.save(self.tmpdirname, corpus=self.text)
+        self.retriever.load(self.tmpdirname, load_corpus=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        # remove the temp dir with rmtree
+        shutil.rmtree(cls.tmpdirname)
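To run just the new test class, a standard unittest invocation works; a sketch assuming the tests package is importable from the repository root:

```python
import unittest

# Load and run only the new non-ASCII round-trip tests
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.core.test_save_load.TestBM25SNonASCIILoadingSaving"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```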
