diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py
index 27e91b84e29..6b74834385e 100644
--- a/torchrl/data/replay_buffers/checkpointers.py
+++ b/torchrl/data/replay_buffers/checkpointers.py
@@ -19,6 +19,7 @@
 )
 from tensordict.memmap import MemoryMappedTensor
 from tensordict.utils import _STRDTYPE2DTYPE
+
 from torchrl.data.replay_buffers.utils import (
     _save_pytree,
     Flat2TED,
@@ -93,6 +94,7 @@ def dumps(self, storage, path):
         if is_tensor_collection(_storage):
             if (
                 _storage.is_memmap()
+                and _storage.saved_path
                 and Path(_storage.saved_path).absolute() == Path(path).absolute()
             ):
                 _storage.memmap_refresh_()
@@ -170,9 +172,28 @@ def loads(self, storage, path):
         for hook in self._load_hooks:
             _storage = hook(_storage, out=dest)
         if not storage.initialized:
-            # this should not be reached if is_pytree=True
-            storage._init(_storage[0])
-            storage._storage.update_(_storage)
+            from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+
+            if (
+                isinstance(storage, LazyMemmapStorage)
+                and storage.scratch_dir
+                and Path(storage.scratch_dir).absolute() == Path(path).absolute()
+            ):
+                storage._storage = TensorDict.load_memmap(path)
+                storage.initialized = True
+            else:
+                # this should not be reached if is_pytree=True
+                storage._init(_storage[0])
+                storage._storage.update_(_storage)
+        elif (
+            storage._storage.is_memmap()
+            and storage._storage.saved_path
+            and Path(storage._storage.saved_path).absolute()
+            == Path(path).absolute()
+        ):
+            # If the storage is already where it should be, we don't need to load anything.
+            storage._storage.memmap_refresh_()
+
         else:
             storage._storage.copy_(_storage)
         storage._len = _len
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 3c540c7ff3e..58b1729296d 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -917,6 +917,23 @@ class LazyMemmapStorage(LazyTensorStorage):
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
 
+    .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
+        already stored to avoid executing long copies of data that is already stored on disk.
+        This will only work if the default :class:`~torchrl.data.TensorStorageCheckpointer` checkpointer is used.
+        Example:
+            >>> from tensordict import TensorDict
+            >>> from torchrl.data import TensorStorage, LazyMemmapStorage, ReplayBuffer
+            >>> import tempfile
+            >>> from pathlib import Path
+            >>> import time
+            >>> td = TensorDict(a=0, b=1).expand(1000).clone()
+            >>> # We pass a path that is <main_path>/storage to LazyMemmapStorage
+            >>> rb_memmap = ReplayBuffer(storage=LazyMemmapStorage(10_000_000, scratch_dir="dump/storage"))
+            >>> rb_memmap.extend(td);
+            >>> # Checkpointing in `dump` is a zero-copy, as the data is already in `dump/storage`
+            >>> rb_memmap.dumps(Path("./dump"))
+
     Examples:
         >>> data = TensorDict({
         ...     "some data": torch.randn(10, 11),