From 25bd8a5f1e86108f03d478a9d701bc142fb23761 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 6 Dec 2023 08:03:29 +0000
Subject: [PATCH] [Feature] pickle-free RB checkpointing (#1733)
---
 docs/source/reference/data.rst                |  69 +++
 test/test_rb.py                               | 424 +++++++++++++-----
 torchrl/data/replay_buffers/replay_buffers.py |  75 ++++
 torchrl/data/replay_buffers/samplers.py       | 133 +++++-
 torchrl/data/replay_buffers/storages.py       |  94 +++-
 torchrl/data/replay_buffers/writers.py        |  88 +++-
 6 files changed, 749 insertions(+), 134 deletions(-)

diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index d41a0c0e2e4..55ebd12e867 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -87,6 +87,49 @@ write onto the storage. The following code snippet examplifies this feature:
 ...     assert (rb["_data", "a"][:10] == 0).all()  # data from main process
 ...     assert (rb["_data", "a"][10:20] == 1).all()  # data from remote process
 
+Sharing replay buffers across processes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Replay buffers can be shared between processes as long as their components are
+sharable. This feature allows multiple processes to collect data and populate a shared
+replay buffer collaboratively, rather than centralizing the data on the main process,
+which can incur some data transmission overhead.
+
+Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage`
+or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
+as long as they are instantiated and their content is stored as memory-mapped
+tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictMaxValueWriter`
+are currently not sharable, and the same goes for stateful samplers such as
+:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
+
+A shared replay buffer can be read and extended on any process that has access
+to it, as the following example shows:
+
+  >>> import pickle
+  >>>
+  >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
+  >>> import torch
+  >>> from torch import multiprocessing as mp
+  >>> from tensordict import TensorDict
+  >>>
+  >>> def worker(rb):
+  ...     td = TensorDict({"a": torch.ones(10)}, [10])
+  ...     # Extends the shared replay buffer on a subprocess
+  ...     rb.extend(td)
+  >>>
+  >>> if __name__ == "__main__":
+  ...     rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
+  ...     td = TensorDict({"a": torch.zeros(10)}, [10])
+  ...     # extends the replay buffer on the main process
+  ...     rb.extend(td)
+  ...
+  ...     proc = mp.Process(target=worker, args=(rb,))
+  ...     proc.start()
+  ...     proc.join()
+  ...     # Checks that the length of the buffer equals the length of both
+  ...     # extensions (local and remote)
+  ...     assert len(rb) == 20
+
 Storing trajectories
 ~~~~~~~~~~~~~~~~~~~~
 
@@ -131,6 +174,32 @@ can be used:
         device=None,
         is_shared=False)
 
+Checkpointing Replay Buffers
+----------------------------
+
+Each component of the replay buffer can potentially be stateful and, as such,
+require a dedicated way of being serialized.
+Our replay buffers enjoy two separate APIs for saving their state on disk:
+:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the
+data of each component except transforms (storage, writer, sampler) using memory-mapped
+tensors and JSON files for the metadata.
This will work across all classes except +:class:`~torchrl.data.replay_buffers.storages.ListStorage`, which content +cannot be anticipated (and as such does not comply with memory-mapped data +structures such as those that can be found in the tensordict library). +This API guarantees that a buffer that is saved and then loaded back will be in +the exact same state, whether we look at the status of its sampler (eg, priority trees) +its writer (eg, max writer heaps) or its storage. +Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public +`dumps` method in a specific folder for each of its components (except transforms +which we don't assume to be serializable using memory-mapped tensors in general). + +Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an +alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data +structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load` +before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback +of this method is that it will struggle to save big data structures, which is a +common setting when using replay buffers. + Datasets -------- diff --git a/test/test_rb.py b/test/test_rb.py index fe9ab157c91..f740e07e8ca 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -82,7 +82,9 @@ @pytest.mark.parametrize( "sampler", [samplers.RandomSampler, samplers.PrioritizedSampler] ) -@pytest.mark.parametrize("writer", [writers.RoundRobinWriter]) +@pytest.mark.parametrize( + "writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter] +) @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage]) @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: @@ -106,7 +108,9 @@ def _get_datum(self, rb_type): elif ( rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): - data = TensorDict({"a": torch.randint(100, (1,))}, []) + data = TensorDict( + {"a": torch.randint(100, (1,)), "next": {"reward": torch.randn(1)}}, [] + ) else: raise NotImplementedError(rb_type) return data @@ -121,6 +125,7 @@ def _get_data(self, rb_type, size): { "a": torch.randint(100, (size,)), "b": TensorDict({"c": torch.randint(100, (size,))}, [size]), + "next": {"reward": torch.randn(size, 1)}, }, [size], ) @@ -138,6 +143,12 @@ def test_add(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_datum(rb_type) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.add(data) + return rb.add(data) s = rb.sample(1) assert s.ndim, s @@ -155,7 +166,22 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): writer = writer() writer.register_storage(storage) batch1 = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(batch1) and isinstance(storage, TensorStorage) + cond = ( + OLD_TORCH + and not isinstance(writer, TensorDictMaxValueWriter) + and size < len(batch1) + and isinstance(storage, TensorStorage) + ) + + if isinstance(batch1, torch.Tensor) and isinstance( + writer, TensorDictMaxValueWriter + ): + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + writer.extend(batch1) + return + with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -167,13 +193,19 
@@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): assert writer._cursor == 5 # Added more data than storage max size elif size < 5: - assert writer._cursor == 5 - size + # if Max writer, we don't necessarily overwrite existing values so + # we just check that the cursor is before the threshold + if isinstance(writer, TensorDictMaxValueWriter): + assert writer._cursor <= 5 - size + else: + assert writer._cursor == 5 - size # Added as data as storage max size else: assert writer._cursor == 0 - batch2 = self._get_data(rb_type, size=size - 1) - writer.extend(batch2) - assert writer._cursor == size - 1 + if not isinstance(writer, TensorDictMaxValueWriter): + batch2 = self._get_data(rb_type, size=size - 1) + writer.extend(batch2) + assert writer._cursor == size - 1 def test_extend(self, rb_type, sampler, writer, storage, size): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: @@ -185,7 +217,21 @@ def test_extend(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return + length = min(rb._storage.max_size, len(rb) + data.shape[0]) + if writer is TensorDictMaxValueWriter: + data["next", "reward"][-length:] = 1_000_000 with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -209,7 +255,10 @@ def test_extend(self, rb_type, sampler, writer, storage, size): raise RuntimeError("did not find match") data2 = self._get_data(rb_type, size=2 * size + 2) cond = ( - OLD_TORCH and size < len(data2) and isinstance(rb._storage, TensorStorage) + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data2) + and isinstance(rb._storage, TensorStorage) ) with pytest.warns( UserWarning, @@ -227,7 +276,18 @@ def test_sample(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -263,7 +323,18 @@ def test_index(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + 
return with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -390,11 +461,124 @@ class TC: with pytest.warns( DeprecationWarning, match="Support for Memmap device other than CPU" ): + # this is rather brittle and will fail with some indices + # when both device (storage and data) don't match (eg, range()) storage.set(0, data) else: storage.set(0, data) assert storage.get(0).device.type == device_storage.type + @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) + @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) + @pytest.mark.parametrize("init_out", [True, False]) + def test_storage_state_dict(self, storage_in, storage_out, init_out): + buffer_size = 100 + if storage_in == "memmap": + storage_in = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_in == "tensor": + storage_in = LazyTensorStorage(buffer_size, device="cpu") + if storage_out == "memmap": + storage_out = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_out == "tensor": + storage_out = LazyTensorStorage(buffer_size, device="cpu") + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 + ) + # fill replay buffer with random data + transition = TensorDict( + { + "observation": torch.ones(1, 4), + "action": torch.ones(1, 2), + "reward": torch.ones(1, 1), + "dones": torch.ones(1, 1), + "next": {"observation": torch.ones(1, 4)}, + }, + batch_size=1, + ) + for _ in range(3): + replay_buffer.extend(transition) + + state_dict = replay_buffer.state_dict() + + new_replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=3, + storage=storage_out, + batch_size=state_dict["_batch_size"], + ) + if init_out: + new_replay_buffer.extend(transition) + + new_replay_buffer.load_state_dict(state_dict) + s = new_replay_buffer.sample() + assert (s.exclude("index") == 1).all() + + @pytest.mark.parametrize("device_data", get_default_devices()) + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize("data_type", ["tensor", "tc", "td"]) + @pytest.mark.parametrize("isinit", [True, False]) + def test_storage_dumps_loads( + self, device_data, storage_type, data_type, isinit, tmpdir + ): + dir_rb = tmpdir / "rb" + dir_save = tmpdir / "save" + dir_rb.mkdir() + dir_save.mkdir() + torch.manual_seed(0) + + @tensorclass + class TC: + tensor: torch.Tensor + td: TensorDict + text: str + + if data_type == "tensor": + data = torch.randint(10, (3,), device=device_data) + elif data_type == "td": + data = TensorDict( + { + "a": torch.randint(10, (3,), device=device_data), + "b": TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, + batch_size=[3], + ), + }, + batch_size=[3], + device=device_data, + ) + elif data_type == "tc": + data = TC( + tensor=torch.randint(10, (3,), device=device_data), + td=TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, batch_size=[3] + ), + text="some text", + batch_size=[3], + device=device_data, + ) + else: + raise NotImplementedError + if storage_type in (LazyMemmapStorage,): + storage = storage_type(max_size=10, scratch_dir=dir_rb) + else: + storage = storage_type(max_size=10) + # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index + storage.set(range(3), data.cpu()) + storage.dumps(dir_save) + # check we can dump twice + storage.dumps(dir_save) + storage_recover = storage_type(max_size=10) + if isinit: + storage_recover.set(range(3), data.cpu().zero_()) + 
storage_recover.loads(dir_save) + if data_type == "tensor": + torch.testing.assert_close(storage._storage, storage_recover._storage) + else: + assert_allclose_td(storage._storage, storage_recover._storage) + if data == "tc": + assert storage._storage.text == storage_recover._storage.text + @pytest.mark.parametrize("max_size", [1000]) @pytest.mark.parametrize("shape", [[3, 4]]) @@ -488,7 +672,8 @@ def test_prototype_prb(priority_key, contiguous, device): "_idx": torch.arange(3).view(3, 1), }, batch_size=[3], - ).to(device) + device=device, + ) rb.extend(td1) s = rb.sample() assert s.batch_size == torch.Size([5]) @@ -503,7 +688,8 @@ def test_prototype_prb(priority_key, contiguous, device): "_idx": torch.arange(5).view(5, 1), }, batch_size=[5], - ).to(device) + device=device, + ) rb.extend(td2) s = rb.sample() assert s.batch_size == torch.Size([5]) @@ -1174,125 +1360,111 @@ def test_replay_buffer_iter(size, drop_last): assert i == (size - 1) // 3 -class TestStateDict: - @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) - @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) - @pytest.mark.parametrize("init_out", [True, False]) - def test_load_state_dict(self, storage_in, storage_out, init_out): - buffer_size = 100 - if storage_in == "memmap": - storage_in = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_in == "tensor": - storage_in = LazyTensorStorage(buffer_size, device="cpu") - if storage_out == "memmap": - storage_out = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_out == "tensor": - storage_out = LazyTensorStorage(buffer_size, device="cpu") - - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 +@pytest.mark.parametrize("size", [20, 25, 30]) +@pytest.mark.parametrize("batch_size", [1, 10, 15]) +@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) +@pytest.mark.parametrize("device", get_default_devices()) +class TestMaxValueWriter: + def test_max_value_writer(self, size, batch_size, reward_ranges, device): + torch.manual_seed(0) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), ) - # fill replay buffer with random data - transition = TensorDict( + + max_reward1, max_reward2, max_reward3 = reward_ranges + + td = TensorDict( { - "observation": torch.ones(1, 4), - "action": torch.ones(1, 2), - "reward": torch.ones(1, 1), - "dones": torch.ones(1, 1), - "next": {"observation": torch.ones(1, 4)}, + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), }, - batch_size=1, + batch_size=size, + device=device, ) - for _ in range(3): - replay_buffer.extend(transition) - - state_dict = replay_buffer.state_dict() + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward1).all() + assert (0 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - new_replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=3, - storage=storage_out, - batch_size=state_dict["_batch_size"], + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, ) - if init_out: - new_replay_buffer.extend(transition) - - new_replay_buffer.load_state_dict(state_dict) - s = new_replay_buffer.sample() - assert (s.exclude("index") == 1).all() - - 
-@pytest.mark.parametrize("size", [20, 25, 30]) -@pytest.mark.parametrize("batch_size", [1, 10, 15]) -@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_max_value_writer(size, batch_size, reward_ranges, device): - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(size, device=device), - sampler=SamplerWithoutReplacement(), - batch_size=batch_size, - writer=TensorDictMaxValueWriter(rank_key="key"), - ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward2).all() + assert (max_reward1 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - max_reward1, max_reward2, max_reward3 = reward_ranges + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) - td = TensorDict( - { - "key": torch.clamp_max(torch.rand(size), max=max_reward1), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= max_reward1).all() - assert (0 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) + for sample in td: + rb.add(sample) - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= max_reward2).all() - assert (max_reward1 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) + sample = rb.sample() + assert (sample.get("key") <= max_reward3).all() + assert (max_reward2 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) + # Finally, test the case when no obs should be added + td = TensorDict( + { + "key": torch.zeros(size), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") != 0).all() - for sample in td: - rb.add(sample) + def test_max_value_writer_serialize( + self, size, batch_size, reward_ranges, device, tmpdir + ): + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), + ) - sample = rb.sample() - assert (sample.get("key") <= max_reward3).all() - assert (max_reward2 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) + max_reward1, max_reward2, max_reward3 = reward_ranges - # Finally, test the case when no obs should be added - td = TensorDict( - { - "key": torch.zeros(size), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") != 0).all() + td = TensorDict( + { + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + rb._writer.dumps(tmpdir) + # check we can dump twice + rb._writer.dumps(tmpdir) + other = TensorDictMaxValueWriter(rank_key="key") + other.loads(tmpdir) + assert len(rb._writer._current_top_values) == 
len(other._current_top_values) + torch.testing.assert_close( + torch.tensor(rb._writer._current_top_values), + torch.tensor(other._current_top_values), + ) class TestMultiProc: @@ -1312,8 +1484,11 @@ def exec_multiproc_rb( storage_type=LazyMemmapStorage, init=True, writer_type=TensorDictRoundRobinWriter, + sampler_type=RandomSampler, ): - rb = TensorDictReplayBuffer(storage=storage_type(21), writer=writer_type()) + rb = TensorDictReplayBuffer( + storage=storage_type(21), writer=writer_type(), sampler=sampler_type() + ) if init: td = TensorDict( {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10] @@ -1353,9 +1528,16 @@ def test_error_nonshared(self): def test_error_maxwriter(self): # TensorDictMaxValueWriter cannot be shared - with pytest.raises(RuntimeError, match="cannot be shared between processed"): + with pytest.raises(RuntimeError, match="cannot be shared between processes"): self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter) + def test_error_prb(self): + # PrioritizedSampler cannot be shared + with pytest.raises(RuntimeError, match="cannot be shared between processes"): + self.exec_multiproc_rb( + sampler_type=lambda: PrioritizedSampler(21, alpha=1.1, beta=0.5) + ) + def test_error_noninit(self): # list storage cannot be shared with pytest.raises(RuntimeError, match="it has not been initialized yet"): diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index cfc6c90bb2c..8b7acdd9d10 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -4,9 +4,11 @@ # LICENSE file in the root directory of this source tree. import collections +import json import threading import warnings from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -230,6 +232,7 @@ def state_dict(self) -> Dict[str, Any]: "_storage": self._storage.state_dict(), "_sampler": self._sampler.state_dict(), "_writer": self._writer.state_dict(), + "_transforms": self._transform.state_dict(), "_batch_size": self._batch_size, } @@ -237,8 +240,80 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) self._sampler.load_state_dict(state_dict["_sampler"]) self._writer.load_state_dict(state_dict["_writer"]) + self._transform.load_state_dict(state_dict["_transforms"]) self._batch_size = state_dict["_batch_size"] + def dumps(self, path): + """Saves the replay buffer on disk at the specified path. + + Args: + path (Path or str): path where to save the replay buffer. + + Examples: + >>> import tempfile + >>> import tqdm + >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + >>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler + >>> import torch + >>> from tensordict import TensorDict + >>> # Build and populate the replay buffer + >>> S = 1_000_000 + >>> sampler = PrioritizedSampler(S, 1.1, 1.0) + >>> # sampler = RandomSampler() + >>> storage = LazyMemmapStorage(S) + >>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler) + >>> + >>> for _ in tqdm.tqdm(range(100)): + ... td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100]) + ... rb.extend(td) + ... sample = rb.sample(32) + ... 
rb.update_tensordict_priority(sample) + >>> # save and load the buffer + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... rb.dumps(tmpdir) + ... + ... sampler = PrioritizedSampler(S, 1.1, 1.0) + ... # sampler = RandomSampler() + ... storage = LazyMemmapStorage(S) + ... rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler) + ... rb_load.loads(tmpdir) + ... assert len(rb) == len(rb_load) + + """ + path = Path(path).absolute() + path.mkdir(exist_ok=True) + self._storage.dumps(path / "storage") + self._sampler.dumps(path / "sampler") + self._writer.dumps(path / "writer") + # fall back on state_dict for transforms + transform_sd = self._transform.state_dict() + if transform_sd: + torch.save(transform_sd, path / "transform.t") + with open(path / "buffer_metadata.json", "w") as file: + json.dump({"batch_size": self._batch_size}, file) + + def loads(self, path): + """Loads a replay buffer state at the given path. + + The buffer should have matching components and be saved using :meth:`~.dumps`. + + Args: + path (Path or str): path where the replay buffer was saved. + + See :meth:`~.dumps` for more info. + + """ + path = Path(path).absolute() + self._storage.loads(path / "storage") + self._sampler.loads(path / "sampler") + self._writer.loads(path / "writer") + # fall back on state_dict for transforms + if (path / "transform.t").exists(): + self._transform.load_state_dict(torch.load(path / "transform.t")) + with open(path / "buffer_metadata.json", "r") as file: + metadata = json.load(file) + self._batch_size = metadata["batch_size"] + def add(self, data: Any) -> int: """Add a single element to the replay buffer. diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 16660aff90f..fde6ed9b69e 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -2,14 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import json import warnings from abc import ABC, abstractmethod -from copy import deepcopy +from copy import copy, deepcopy +from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Tuple, Union import numpy as np import torch +from tensordict import MemoryMappedTensor + from ..._extension import EXTENSION_WARNING try: @@ -68,6 +73,14 @@ def ran_out(self) -> bool: def _empty(self): ... + @abstractmethod + def dumps(self, path): + ... + + @abstractmethod + def loads(self, path): + ... + class RandomSampler(Sampler): """A uniformly random sampler for composable replay buffers. @@ -87,6 +100,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] def _empty(self): pass + def dumps(self, path): + # no op + ... + + def loads(self, path): + # no op + ... + class SamplerWithoutReplacement(Sampler): """A data-consuming sampler that ensures that the same sample is not present in consecutive batches. 
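The checkpointing documentation above also mentions a fallback path through
:meth:`~torchrl.data.ReplayBuffer.state_dict` for cases where ``dumps`` cannot be used.
The sketch below is illustrative and not part of this patch: it assumes the default
sampler and writer, the checkpoint filename is only an example, and the flow mirrors
the new ``test_storage_state_dict`` test:

  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
  >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100), batch_size=8)
  >>> index = rb.extend(TensorDict({"obs": torch.randn(10, 3)}, [10]))
  >>> # the state dict gathers storage, sampler, writer, transform and batch-size state
  >>> torch.save(rb.state_dict(), "rb_checkpoint.pt")
  >>> rb2 = TensorDictReplayBuffer(storage=LazyTensorStorage(100), batch_size=8)
  >>> rb2.load_state_dict(torch.load("rb_checkpoint.pt"))
  >>> assert len(rb2) == len(rb)
  >>> sample = rb2.sample()  # the restored buffer can be sampled right away

Unlike ``dumps``, this route funnels the whole buffer state through :func:`torch.save`
(i.e. pickle), which is convenient for small buffers but struggles with large storages,
as noted in the documentation.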
@@ -114,6 +135,29 @@ def __init__(self, drop_last: bool = False): self.drop_last = drop_last self._ran_out = False + def dumps(self, path): + path = Path(path) + path.mkdir(exist_ok=True) + + with open(path / "sampler_metadata.json", "w") as file: + json.dump( + { + "len_storage": self.len_storage, + "_sample_list": self._sample_list, + "drop_last": self.drop_last, + "_ran_out": self._ran_out, + }, + file, + ) + + def loads(self, path): + with open(path / "sampler_metadata.json", "r") as file: + metadata = json.load(file) + self._sample_list = metadata["_sample_list"] + self.len_storage = metadata["len_storage"] + self.drop_last = metadata["drop_last"] + self._ran_out = metadata["_ran_out"] + def _single_sample(self, len_storage, batch_size): index = self._sample_list[:batch_size] self._sample_list = self._sample_list[batch_size:] @@ -208,6 +252,14 @@ def __init__( self.dtype = dtype self._init() + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Samplers of type {type(self)} cannot be shared between processes." + ) + state = copy(self.__dict__) + return state + def _init(self): if self.dtype in (torch.float, torch.FloatType, torch.float32): self._sum_tree = SumSegmentTreeFp32(self._max_capacity) @@ -276,11 +328,15 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: def add(self, index: int) -> None: super().add(index) - self._add_or_extend(index) + if index is not None: + # some writers don't systematically write data and can return None + self._add_or_extend(index) def extend(self, index: torch.Tensor) -> None: super().extend(index) - self._add_or_extend(index) + if index is not None: + # some writers don't systematically write data and can return None + self._add_or_extend(index) def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] @@ -339,3 +395,74 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._max_priority = state_dict["_max_priority"] self._sum_tree = state_dict.pop("_sum_tree") self._min_tree = state_dict.pop("_min_tree") + + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + try: + mm_st = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + except FileNotFoundError: + mm_st = MemoryMappedTensor.empty( + (self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.empty( + (self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + mm_st.copy_( + torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)]) + ) + mm_mt.copy_( + torch.tensor([self._min_tree[i] for i in range(self._max_capacity)]) + ) + with open(path / "sampler_metadata.json", "w") as file: + json.dump( + { + "_alpha": self._alpha, + "_beta": self._beta, + "_eps": self._eps, + "_max_priority": self._max_priority, + "_max_capacity": self._max_capacity, + }, + file, + ) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "sampler_metadata.json", "r") as file: + metadata = json.load(file) + self._alpha = metadata["_alpha"] + self._beta = metadata["_beta"] + self._eps = metadata["_eps"] + self._max_priority = metadata["_max_priority"] + _max_capacity = metadata["_max_capacity"] + if _max_capacity != self._max_capacity: + raise RuntimeError( 
+ f"max capacity of loaded metadata ({_max_capacity}) differs from self._max_capacity ({self._max_capacity})." + ) + mm_st = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + for i, elt in enumerate(mm_st.tolist()): + self._sum_tree[i] = elt + for i, elt in enumerate(mm_mt.tolist()): + self._min_tree[i] = elt diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 2c4ec8acc6b..4e01eeffb67 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -4,18 +4,21 @@ # LICENSE file in the root directory of this source tree. import abc +import json import os import warnings from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Sequence, Union +import numpy as np import torch from tensordict import is_tensorclass from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase -from tensordict.utils import expand_right +from tensordict.utils import _STRDTYPE2DTYPE, expand_right from torch import multiprocessing as mp from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE @@ -54,6 +57,14 @@ def set(self, cursor: int, data: Any): def get(self, index: int) -> Any: ... + @abc.abstractmethod + def dumps(self, path): + ... + + @abc.abstractmethod + def loads(self, path): + ... + def attach(self, buffer: Any) -> None: """This function attaches a sampler to this storage. @@ -109,8 +120,23 @@ def __init__(self, max_size: int): super().__init__(max_size) self._storage = [] + def dumps(self, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + + def loads(self, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + def set(self, cursor: Union[int, Sequence[int], slice], data: Any): if not isinstance(cursor, INT_CLASSES): + if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( + isinstance(cursor, np.ndarray) and cursor.size <= 1 + ): + self.set(int(cursor), data) + return if isinstance(cursor, slice): self._storage[cursor] = data return @@ -269,6 +295,72 @@ def __init__(self, storage, max_size=None, device="cpu"): ) self._storage = storage + def dumps(self, path): + path = Path(path) + path.mkdir(exist_ok=True) + + if not self.initialized: + raise RuntimeError("Cannot save a non-initialized storage.") + if isinstance(self._storage, torch.Tensor): + try: + MemoryMappedTensor.from_filename( + shape=self._storage.shape, + filename=path / "storage.memmap", + dtype=self._storage.dtype, + ).copy_(self._storage) + except FileNotFoundError: + MemoryMappedTensor.from_tensor( + self._storage, filename=path / "storage.memmap", copy_existing=True + ) + is_tensor = True + dtype = str(self._storage.dtype) + shape = list(self._storage.shape) + else: + # try to load the path and overwrite. 
+ try: + saved = TensorDict.load_memmap(path) + except FileNotFoundError: + # otherwise create a new one + saved = self._storage.memmap_like(path) + saved.update_(self._storage) + is_tensor = False + dtype = None + shape = None + + with open(path / "storage_metadata.json", "w") as file: + json.dump( + { + "is_tensor": is_tensor, + "dtype": dtype, + "shape": shape, + "len": self._len, + }, + file, + ) + + def loads(self, path): + with open(path / "storage_metadata.json", "r") as file: + metadata = json.load(file) + is_tensor = metadata["is_tensor"] + shape = metadata["shape"] + dtype = metadata["dtype"] + _len = metadata["len"] + if dtype is not None: + shape = torch.Size(shape) + dtype = _STRDTYPE2DTYPE[dtype] + if is_tensor: + _storage = MemoryMappedTensor.from_filename( + path / "storage.memmap", shape=shape, dtype=dtype + ).clone() + else: + _storage = TensorDict.load_memmap(path) + if not self.initialized: + self._storage = _storage + self.initialized = True + else: + self._storage.copy_(_storage) + self._len = _len + @property def _len(self): _len_value = self.__dict__.get("_len_value", None) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index f171fb2a9ff..702898b5292 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. import heapq +import json from abc import ABC, abstractmethod from copy import copy from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Sequence import numpy as np import torch +from tensordict import is_tensor_collection, MemoryMappedTensor +from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp from .storages import Storage @@ -39,6 +43,14 @@ def extend(self, data: Sequence) -> torch.Tensor: def _empty(self): ... + @abstractmethod + def dumps(self, path): + ... + + @abstractmethod + def loads(self, path): + ... + def state_dict(self) -> Dict[str, Any]: return {} @@ -53,6 +65,18 @@ def __init__(self, **kw) -> None: super().__init__(**kw) self._cursor = 0 + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + with open(path / "metadata.json", "w") as file: + json.dump({"cursor": self._cursor}, file) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "metadata.json", "r") as file: + metadata = json.load(file) + self._cursor = metadata["cursor"] + def add(self, data: Any) -> int: ret = self._cursor _cursor = self._cursor @@ -181,6 +205,10 @@ def __init__(self, rank_key=None, **kwargs) -> None: def get_insert_index(self, data: Any) -> int: """Returns the index where the data should be inserted, or ``None`` if it should not be inserted.""" + if not is_tensor_collection(data): + raise RuntimeError( + f"{type(self)} expects data to be a tensor collection (tensordict or tensorclass). Found a {type(data)} instead." + ) if data.batch_dims > 1: raise RuntimeError( "Expected input tensordict to have no more than 1 dimension, got" @@ -188,7 +216,7 @@ def get_insert_index(self, data: Any) -> int: ) ret = None - rank_data = data.get(("_data", self._rank_key)) + rank_data = data.get("_data", default=data).get(self._rank_key) # If time dimension, sum along it. 
rank_data = rank_data.sum(-1).item() @@ -198,7 +226,6 @@ def get_insert_index(self, data: Any) -> int: # If the buffer is not full, add the data if len(self._current_top_values) < self._storage.max_size: - ret = self._cursor self._cursor = (self._cursor + 1) % self._storage.max_size @@ -246,13 +273,17 @@ def extend(self, data: Sequence) -> None: # Replace the data in the storage all at once if len(data_to_replace) > 0: keys, values = zip(*data_to_replace.items()) - index = data.get("index") + index = data.get("index", None) + dtype = index.dtype if index is not None else torch.long + device = index.device if index is not None else data.device values = list(values) - keys = index[values] = torch.tensor( - keys, dtype=index.dtype, device=index.device - ) - data.set("index", index) - self._storage[keys] = data[values] + keys = torch.tensor(keys, dtype=dtype, device=device) + if index is not None: + index[values] = keys + data.set("index", index) + self._storage.set(keys, data[values]) + return keys.long() + return None def _empty(self) -> None: self._cursor = 0 @@ -261,7 +292,46 @@ def _empty(self) -> None: def __getstate__(self): if get_spawning_popen() is not None: raise RuntimeError( - f"Writers of type {type(self)} cannot be shared between processed." + f"Writers of type {type(self)} cannot be shared between processes." ) state = copy(self.__dict__) return state + + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + t = torch.tensor(self._current_top_values) + try: + MemoryMappedTensor.from_filename( + filename=path / "current_top_values.memmap", + shape=t.shape, + dtype=t.dtype, + ).copy_(t) + except FileNotFoundError: + MemoryMappedTensor.from_tensor( + t, filename=path / "current_top_values.memmap" + ) + with open(path / "metadata.json", "w") as file: + json.dump( + { + "cursor": self._cursor, + "rank_key": self._rank_key, + "dtype": str(t.dtype), + "shape": list(t.shape), + }, + file, + ) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "metadata.json", "r") as file: + metadata = json.load(file) + self._cursor = metadata["cursor"] + self._rank_key = metadata["rank_key"] + shape = torch.Size(metadata["shape"]) + dtype = metadata["dtype"] + self._current_top_values = MemoryMappedTensor.from_filename( + filename=path / "current_top_values.memmap", + dtype=_STRDTYPE2DTYPE[dtype], + shape=shape, + ).tolist()
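Together with the storage and sampler changes above, the new ``dumps``/``loads`` pair on
:class:`~torchrl.data.replay_buffers.writers.TensorDictMaxValueWriter` persists the writer's
top-value heap as a memory-mapped tensor plus a small JSON metadata file, keeping the
checkpoint pickle-free. The usage sketch below is illustrative only; it mirrors the new
``test_max_value_writer_serialize`` test, with the temporary directory standing in for any
checkpoint folder:

  >>> import tempfile
  >>> import torch
  >>> from tensordict import TensorDict
  >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
  >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
  >>> from torchrl.data.replay_buffers.writers import TensorDictMaxValueWriter
  >>> rb = TensorDictReplayBuffer(
  ...     storage=LazyTensorStorage(10),
  ...     sampler=SamplerWithoutReplacement(),
  ...     writer=TensorDictMaxValueWriter(rank_key="key"),
  ...     batch_size=4,
  ... )
  >>> td = TensorDict({"key": torch.rand(10), "obs": torch.rand(10)}, [10])
  >>> index = rb.extend(td)
  >>> with tempfile.TemporaryDirectory() as tmpdir:
  ...     # writes current_top_values.memmap and metadata.json under tmpdir
  ...     rb._writer.dumps(tmpdir)
  ...     other = TensorDictMaxValueWriter(rank_key="key")
  ...     other.loads(tmpdir)
  ...     torch.testing.assert_close(
  ...         torch.tensor(rb._writer._current_top_values),
  ...         torch.tensor(other._current_top_values),
  ...     )

In normal use the writer is checkpointed implicitly through
:meth:`~torchrl.data.ReplayBuffer.dumps`, which calls each component's ``dumps`` in its own
sub-folder; the snippet above only isolates the writer to show what gets written to disk.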