-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
lru_cache_test.py
158 lines (116 loc) · 4.34 KB
/
lru_cache_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib.util
import tempfile
import time
from absl.testing import absltest
from jax._src import path as pathlib
from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache
import jax._src.test_util as jtu
class LRUCacheTestCase(jtu.JaxTestCase):
name: str | None
path: pathlib.Path | None
def setUp(self):
if importlib.util.find_spec("filelock") is None:
self.skipTest("filelock is not installed")
super().setUp()
tmpdir = tempfile.TemporaryDirectory()
self.enter_context(tmpdir)
self.name = tmpdir.name
self.path = pathlib.Path(self.name)
def tearDown(self):
self.path = None
self.name = None
super().tearDown()
def assertCacheKeys(self, keys):
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), {self.path / f"{key}{_CACHE_SUFFIX}" for key in keys})
class LRUCacheTest(LRUCacheTestCase):
  """Tests `LRUCache` lookups, insertions, eviction order and size limits."""

  def test_get_nonexistent_key(self):
    """A key that was never written reads back as `None`."""
    lru = LRUCache(self.name, max_size=-1)
    missing = lru.get("a")
    self.assertIsNone(missing)

  def test_put_and_get_key(self):
    """Stored values can be read back and each key gets its own cache file."""
    lru = LRUCache(self.name, max_size=-1)

    lru.put("a", b"a")
    self.assertEqual(lru.get("a"), b"a")
    self.assertCacheKeys(("a",))

    # Adding a second key must leave the first one intact.
    lru.put("b", b"b")
    self.assertEqual(lru.get("a"), b"a")
    self.assertEqual(lru.get("b"), b"b")
    self.assertCacheKeys(("a", "b"))

  def test_put_empty_value(self):
    """A zero-length value round-trips as empty bytes."""
    lru = LRUCache(self.name, max_size=-1)
    lru.put("a", b"")
    stored = lru.get("a")
    self.assertEqual(stored, b"")

  def test_put_empty_key(self):
    """Writing under an empty key raises `ValueError`."""
    lru = LRUCache(self.name, max_size=-1)
    with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
      lru.put("", b"a")

  def test_eviction(self):
    """With room for two entries, the least recently used one is dropped."""
    lru = LRUCache(self.name, max_size=2)

    lru.put("a", b"a")
    lru.put("b", b"b")

    # Sleep so that the timestamp of `b` ends up strictly later than `a`'s.
    time.sleep(1)
    lru.get("b")

    # Inserting `c` has to push out `a`.
    lru.put("c", b"c")
    self.assertCacheKeys(("b", "c"))

    # Touching `b` again leaves `c` as the least recently used entry.
    time.sleep(1)
    lru.get("b")

    # Inserting `d` has to push out `c`.
    lru.put("d", b"d")
    self.assertCacheKeys(("b", "d"))

  def test_eviction_with_empty_value(self):
    """Zero-length entries do not trigger eviction but can be evicted."""
    lru = LRUCache(self.name, max_size=1)

    lru.put("a", b"a")

    # `b` has length 0, so nothing is evicted even though the cache is
    # already full.
    lru.put("b", b"")
    self.assertCacheKeys(("a", "b"))

    # Touching `a` leaves `b` as the least recently used entry.
    time.sleep(1)
    lru.get("a")

    # Inserting `c` first evicts the least recently used file (`b`); that
    # does not free enough room for `c`, so `a` is evicted as well.
    lru.put("c", b"c")
    self.assertCacheKeys(("c",))

  def test_existing_cache_dir(self):
    """A new `LRUCache` over an existing directory keeps entries and order."""
    lru = LRUCache(self.name, max_size=2)
    lru.put("a", b"a")

    # Drop and recreate the cache object to mimic another process
    # reinitializing the cache.
    del lru
    lru = LRUCache(self.name, max_size=2)
    self.assertEqual(lru.get("a"), b"a")

    # The LRU policy must still hold after reinitialization.
    lru.put("b", b"b")

    # Touching `a` leaves `b` as the least recently used entry.
    time.sleep(1)
    lru.get("a")

    # Inserting `c` has to push out `b`.
    lru.put("c", b"c")
    self.assertCacheKeys(("a", "c"))

  def test_max_size(self):
    """A value larger than `max_size` warns and is not stored."""
    lru = LRUCache(self.name, max_size=1)

    msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum "
           r"cache size of \d+ bytes")
    with self.assertWarnsRegex(UserWarning, msg):
      lru.put("a", b"aaaa")

    # The oversized value must be neither readable nor present on disk.
    self.assertIsNone(lru.get("a"))
    self.assertCacheKeys(())
if __name__ == "__main__":
  # Run this module's tests through absltest with the JAX test loader.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)