Replace ujson with orjson, add load and close for JsonlCorpus (#49)
* Update JSON handling to try orjson -> ujson -> json, in that order

* Change core requirement from orjson to ujson

* Fix use of json.load

* Update JsonlCorpus to have load/close methods

* Improve tests to cover close() and load(), and fix how close() and load() work

* Remove test cases (moved to tests)

* Remove orjson_loaded flag

* Fix use of json again

* Update tests and add tests for comparing json and orjson

* Remove usage of ujson

* Finally fix usage of json with orjson

* Final fix (hopefully)

* Update tests
xhluca authored Sep 3, 2024
1 parent 072d242 commit 2b97cc5
Showing 9 changed files with 194 additions and 61 deletions.
18 changes: 8 additions & 10 deletions bm25s/__init__.py
@@ -5,14 +5,12 @@
 import os
 import logging
 from pathlib import Path
+import json
 from typing import Any, Tuple, Dict, Iterable, List, NamedTuple, Union
 
 import numpy as np
 
-try:
-    import ujson as json
-except ImportError:
-    import json
+from .utils import json_functions as json_functions
 
 try:
     from .numba import selection as selection_jit
@@ -743,7 +741,7 @@ def save(
         vocab_path = save_dir / vocab_name
 
         with open(vocab_path, "w") as f:
-            json.dump(self.vocab_dict, f)
+            f.write(json_functions.dumps(self.vocab_dict))
 
         # Save the parameters
         params_path = save_dir / params_name
@@ -784,11 +782,11 @@ def save(
                     continue
 
                 try:
-                    doc = json.dumps(doc)
+                    doc_str = json_functions.dumps(doc)
                 except Exception as e:
                     logging.warning(f"Error saving document at index {i}: {e}")
                 else:
-                    f.write(doc + "\n")
+                    f.write(doc_str + "\n")
 
         # also save corpus.mmindex
        mmidx = utils.corpus.find_newline_positions(save_dir / corpus_name)
@@ -861,12 +859,12 @@ def load(
         # Load the parameters
         params_path = save_dir / params_name
         with open(params_path, "r") as f:
-            params: dict = json.load(f)
+            params: dict = json_functions.loads(f.read())
 
         # Load the vocab dictionary
         vocab_path = save_dir / vocab_name
         with open(vocab_path, "r") as f:
-            vocab_dict = json.load(f)
+            vocab_dict: dict = json_functions.loads(f.read())
 
         # Load the score arrays
         data_path = save_dir / data_name
@@ -903,7 +901,7 @@ def load(
             corpus = []
             with open(corpus_file, "r") as f:
                 for line in f:
-                    doc = json.loads(line)
+                    doc = json_functions.loads(line)
                     corpus.append(doc)
 
             bm25_obj.corpus = corpus
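With these changes, BM25.save and BM25.load route all JSON (de)serialization through json_functions. A minimal round-trip sketch (illustrative only; the directory name and toy corpus are made up, the API calls follow the bm25s README):

import bm25s

corpus = ["a cat is a feline and likes to purr", "a dog is a canine and likes to bark"]
retriever = bm25s.BM25(corpus=corpus)
retriever.index(bm25s.tokenize(corpus))

# vocab and params are written with json_functions.dumps (orjson-backed when available)
retriever.save("bm25s_index_demo")

# params, vocab, and the jsonl corpus are read back with json_functions.loads
reloaded = bm25s.BM25.load("bm25s_index_demo", load_corpus=True)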
2 changes: 1 addition & 1 deletion bm25s/utils/__init__.py
@@ -1 +1 @@
-from . import benchmark, beir, corpus
+from . import benchmark, beir, corpus, json_functions
16 changes: 5 additions & 11 deletions bm25s/utils/beir.py
@@ -7,6 +7,7 @@
     def tqdm(iterable, *args, **kwargs):
         return iterable
 
+from . import json_functions
 
 BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"

@@ -38,11 +39,6 @@ def postprocess_results_for_eval(results, scores, query_ids):
 
 
 def merge_cqa_dupstack(data_path):
-    try:
-        import ujson
-    except ImportError:
-        import json as ujson
-
     data_path = Path(data_path)
     dataset = data_path.name
     assert dataset == "cqadupstack", "Dataset must be CQADupStack"
@@ -62,12 +58,11 @@ def merge_cqa_dupstack(data_path):
            for line in tqdm(
                f2, desc=f"Merging {corpus_name} Corpus", leave=False
            ):
-                # first, read with ujson
-                line = ujson.loads(line)
+                line = json_functions.loads(line)
                # add the corpus name to _id
                line["_id"] = f"{corpus_name}_{line['_id']}"
                # write back to file
-                f.write(ujson.dumps(line))
+                f.write(json_functions.dumps(line))
                f.write("\n")
 
    # now, do the same for queries.jsonl
@@ -83,12 +78,11 @@ def merge_cqa_dupstack(data_path):
            for line in tqdm(
                f2, desc=f"Merging {corpus_name} Queries", leave=False
            ):
-                # first, read with ujson
-                line = ujson.loads(line)
+                line = json_functions.loads(line)
                # add the corpus name to _id
                line["_id"] = f"{corpus_name}_{line['_id']}"
                # write back to file
-                f.write(ujson.dumps(line))
+                f.write(json_functions.dumps(line))
                f.write("\n")
 
    # now, do the same for qrels/test.tsv
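For context, merge_cqa_dupstack rewrites each sub-corpus jsonl in place, now round-tripping JSON through json_functions instead of ujson. A usage sketch (assuming the BEIR CQADupStack archive has already been downloaded and extracted to a hypothetical local datasets/cqadupstack directory):

from bm25s.utils.beir import merge_cqa_dupstack

# Prefixes every document/query _id with its sub-forum name and merges the
# per-forum corpus.jsonl / queries.jsonl files into a single dataset.
merge_cqa_dupstack("datasets/cqadupstack")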
81 changes: 45 additions & 36 deletions bm25s/utils/corpus.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 try:
-    import ujson as json
+    import orjson as json
 except ImportError:
     import json
 
@@ -17,6 +17,7 @@
     def tqdm(iterable=None, *args, **kwargs):
         return iterable
 
+from . import json_functions
 
 def change_extension(path, new_extension):
     path = str(path)
@@ -54,14 +55,14 @@ def save_mmindex(indexes, path):
     path = str(path)
     index_file = change_extension(path, ".mmindex.json")
     with open(index_file, "w") as f:
-        f.write(json.dumps(indexes))
+        f.write(json_functions.dumps(indexes))
 
 
 def load_mmindex(path):
     path = str(path)
     index_file = change_extension(path, ".mmindex.json")
     with open(index_file, "r") as f:
-        return json.loads(f.read())
+        return json_functions.loads(f.read())
 
 
 # now we can jump to any line in the file thanks to the index and mmap
@@ -72,7 +73,7 @@ def get_line(
     encoding="utf-8",
     file_obj=None,
     mmap_obj=None,
-):
+) -> str:
     path = str(path)
     if file_obj is None:
         file_obj = open(path, "r")
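The index-then-seek flow these helpers implement is the one the demo block removed at the bottom of this file used to exercise; as a standalone sketch (hypothetical file.jsonl path):

import json
from bm25s.utils.corpus import find_newline_positions, save_mmindex, load_mmindex, get_line

# create a small jsonl file to index (mirrors the removed __main__ demo)
with open("file.jsonl", "w") as f:
    for i in range(10):
        f.write(json.dumps({"line": i}) + "\n")

# build and persist the newline-offset index
offsets = find_newline_positions("file.jsonl")
save_mmindex(offsets, "file.jsonl")

# later: reload the index and jump straight to line 5 without scanning
mmindex = load_mmindex("file.jsonl")
print(get_line("file.jsonl", 5, mmindex))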
@@ -126,8 +127,10 @@ class JsonlCorpus:
     Which only loads the line you need into memory, and is much faster.
     """
 
-    def __init__(self, path, show_progress=True, leave_progress=True, save_index=True):
+    def __init__(self, path, show_progress=True, leave_progress=True, save_index=True, verbosity=1):
         self.path = path
+        self.verbosity = verbosity
+
         # if the index file does not exist, create it
         if os.path.exists(change_extension(path, ".mmindex.json")):
             self.mmindex = load_mmindex(path)
@@ -141,17 +144,16 @@ def __init__(self, path, show_progress=True, leave_progress=True, save_index=True, verbosity=1):
 
         self.mmindex = mmindex
 
-        self.file_obj = open(path, "r")
-        self.mmap_obj = mmap.mmap(self.file_obj.fileno(), 0, access=mmap.ACCESS_READ)
-        logging.info("Opened file and mmap objects")
+        # Finally, open the file and mmap objects
+        self.load()
 
     def __len__(self):
         return len(self.mmindex)
 
     def __getitem__(self, index):
         # handle multiple indices
         if isinstance(index, int):
-            return json.loads(
+            return json_functions.loads(
                 get_line(
                     self.path,
                     index,
@@ -176,32 +178,39 @@ def __getitem__(self, index):
 
         raise TypeError("Invalid index type")
 
-    def __del__(self):
-        if hasattr(self, "file_obj"):
+    def close(self):
+        """
+        Close the file and mmap objects. This is useful if you want to free up memory. To reopen them, use the `load` method.
+        If you don't call this method, the objects will be closed automatically when the object is deleted.
+        """
+        if hasattr(self, "file_obj") and self.file_obj is not None:
             self.file_obj.close()
-        if hasattr(self, "mmap_obj"):
+            # delete the object
+            del self.file_obj
+            self.file_obj = None
+        if hasattr(self, "mmap_obj") and self.mmap_obj is not None:
             self.mmap_obj.close()
-        logging.info("Closed file and mmap objects")
-
-
-if __name__ == "__main__":
-    # let's test the functions
-    # random jsonl file
-    file = "file.jsonl"
-    # content is random uuids
-    import uuid
-
-    with open(file, "w") as f:
-        for i in range(500):
-            f.write(json.dumps({"uuid": str(uuid.uuid4())}) + "\n")
-
-    # create the index
-    mmindex = find_newline_positions(file)
-    save_mmindex(mmindex, file)
-
-    # read the first line
-    # load the index
-    mmindex = load_mmindex(file)
-    print(get_line(file, 1, mmindex))
-
-    print(get_line(file, 5, mmindex))
+            # delete the object
+            del self.mmap_obj
+            self.mmap_obj = None
+        if self.verbosity >= 1:
+            logging.info("Closed file and mmap objects")
+
+    def load(self):
+        """
+        Load the file and mmap objects. This is useful if you closed them and want to reopen them.
+        Note
+        ----
+        This is called automatically when the object is created. You don't need to call it manually.
+        Also, if there is an existing file and mmap object, this will close them before reopening.
+        """
+        self.close()  # close any existing file and mmap objects
+
+        self.file_obj = open(self.path, "r")
+        self.mmap_obj = mmap.mmap(self.file_obj.fileno(), 0, access=mmap.ACCESS_READ)
+        if self.verbosity >= 1:
+            logging.info("Opened file and mmap objects")
+
+    def __del__(self):
+        self.close()
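A short sketch of the new lifecycle (hypothetical corpus.jsonl path with one JSON object per line): close() releases the file handle and mmap and sets both attributes to None; load() reopens them and is also what __init__ now calls.

from bm25s.utils.corpus import JsonlCorpus

docs = JsonlCorpus("corpus.jsonl")   # load() is called automatically
print(docs[0])                       # mmap-backed random access

docs.close()                         # free the handle and mmap
assert docs.file_obj is None and docs.mmap_obj is None

docs.load()                          # reopen; safe to call repeatedly
print(docs[0])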
22 changes: 22 additions & 0 deletions bm25s/utils/json_functions.py
@@ -0,0 +1,22 @@
+import json
+
+try:
+    import orjson
+    ORJSON_AVAILABLE = True
+except ImportError:
+    ORJSON_AVAILABLE = False
+
+
+def dumps_with_builtin(d: dict) -> str:
+    return json.dumps(d)
+
+def dumps_with_orjson(d: dict) -> str:
+    return orjson.dumps(d).decode('utf-8')
+
+if ORJSON_AVAILABLE:
+    dumps = dumps_with_orjson
+    loads = orjson.loads
+else:
+    dumps = dumps_with_builtin
+    loads = json.loads

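Call sites never need to know which backend won; a minimal usage sketch of the new module:

from bm25s.utils import json_functions

# dumps/loads are bound once at import time: orjson-backed when orjson
# is importable, standard-library json otherwise.
s = json_functions.dumps({"_id": "d1", "text": "hello"})
assert json_functions.loads(s) == {"_id": "d1", "text": "hello"}
print("using orjson:", json_functions.ORJSON_AVAILABLE)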
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@
     long_description = fp.read()
 
 extras_require = {
-    "core": ["jax[cpu]", "ujson", "tqdm", "PyStemmer"],
+    "core": ["jax[cpu]", "orjson", "tqdm", "PyStemmer"],
     "stem": ["PyStemmer"],
     "hf": ["huggingface_hub"],
     "dev": ["black"],
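Since orjson now ships with the core extra (e.g. pip install "bm25s[core]"), the fast path is active by default for that install flavor, while a bare install still falls back to the standard-library json via json_functions.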
22 changes: 21 additions & 1 deletion tests/comparison/test_utils_corpus.py
@@ -29,4 +29,24 @@ def test_utils_corpus(self):
                 doc = json.loads(line)
                 corpus_ids_2.append(doc["_id"])
 
-        self.assertListEqual(corpus_ids, corpus_ids_2)
+        self.assertListEqual(corpus_ids, corpus_ids_2)
+
+        # check if jsonl corpus can be closed
+        assert nq.file_obj is not None, "JsonlCorpus file_obj is None, expected file object"
+        assert nq.mmap_obj is not None, "JsonlCorpus mmap_obj is None, expected mmap object"
+
+        # now, we can close
+        nq.close()
+
+        assert nq.file_obj is None, "JsonlCorpus file_obj is not None, expected None"
+        assert nq.mmap_obj is None, "JsonlCorpus mmap_obj is not None, expected None"
+
+        # check if jsonl corpus can be loaded
+        nq.load()
+
+        assert nq.file_obj is not None, "JsonlCorpus file_obj is None, expected file object"
+        assert nq.mmap_obj is not None, "JsonlCorpus mmap_obj is None, expected mmap object"
+
+        corpus_ids = [doc["_id"] for doc in tqdm(nq)]
+        self.assertListEqual(corpus_ids, corpus_ids_2)
33 changes: 32 additions & 1 deletion tests/core/test_save_load.py
@@ -3,12 +3,21 @@
 from pathlib import Path
 import unittest
 import tempfile
-import bm25s
 import Stemmer  # optional: for stemming
+import unittest.mock
+import json
+
+import bm25s
+from bm25s.utils import json_functions
 
 class TestBM25SLoadingSaving(unittest.TestCase):
+    orjson_should_not_be_installed = False
+    orjson_should_be_installed = True
+
     @classmethod
     def setUpClass(cls):
+        # check that import orjson fails
+        import bm25s
 
         # Create your corpus here
         corpus = [
@@ -35,6 +44,13 @@ def setUpClass(cls):
         cls.stemmer = stemmer
         cls.tmpdirname = tempfile.mkdtemp()
 
+    def setUp(self):
+        # verify that orjson is properly installed
+        try:
+            import orjson
+        except ImportError:
+            self.fail("orjson should be installed to run this test.")
+
     def test_a_save(self):
         # save the retriever to temp dir
         self.retriever.save(
@@ -87,6 +103,21 @@ def test_b_load(self):
 
         # nnoc is stored in self.nnoc
         self.assertTrue((r1.nonoccurrence_array == r2.nonoccurrence_array).all())
+
+    @unittest.mock.patch("bm25s.utils.json_functions.dumps", json_functions.dumps_with_builtin)
+    @unittest.mock.patch("bm25s.utils.json_functions.loads", json.loads)
+    def test_c_save_no_orjson(self):
+        self.assertEqual(json_functions.dumps_with_builtin, json_functions.dumps)
+        self.assertEqual(json_functions.loads, json.loads)
+        self.test_a_save()
+
+    @unittest.mock.patch("bm25s.utils.json_functions.dumps", json_functions.dumps_with_builtin)
+    @unittest.mock.patch("bm25s.utils.json_functions.loads", json.loads)
+    def test_d_load_no_orjson(self):
+        self.assertEqual(json_functions.dumps_with_builtin, json_functions.dumps)
+        self.assertEqual(json_functions.loads, json.loads)
+        self.test_b_load()
+
 
     @classmethod
     def tearDownClass(cls):
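Patching bm25s.utils.json_functions.dumps and .loads works here because the library accesses these functions through the module object (from .utils import json_functions, then json_functions.dumps(...)), so the patched attributes are looked up at call time and the builtin-json fallback path is exercised without uninstalling orjson.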