Skip to content

Commit

Permalink
[BugFix] Better dumps/loads (pytorch#2343)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 31, 2024
1 parent da89826 commit c1093b7
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
27 changes: 24 additions & 3 deletions torchrl/data/replay_buffers/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from tensordict.memmap import MemoryMappedTensor
from tensordict.utils import _STRDTYPE2DTYPE

from torchrl.data.replay_buffers.utils import (
_save_pytree,
Flat2TED,
Expand Down Expand Up @@ -93,6 +94,7 @@ def dumps(self, storage, path):
if is_tensor_collection(_storage):
if (
_storage.is_memmap()
and _storage.saved_path
and Path(_storage.saved_path).absolute() == Path(path).absolute()
):
_storage.memmap_refresh_()
Expand Down Expand Up @@ -170,9 +172,28 @@ def loads(self, storage, path):
for hook in self._load_hooks:
_storage = hook(_storage, out=dest)
if not storage.initialized:
# this should not be reached if is_pytree=True
storage._init(_storage[0])
storage._storage.update_(_storage)
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

if (
isinstance(storage, LazyMemmapStorage)
and storage.scratch_dir
and Path(storage.scratch_dir).absolute() == Path(path).absolute()
):
storage._storage = TensorDict.load_memmap(path)
storage.initialized = True
else:
# this should not be reached if is_pytree=True
storage._init(_storage[0])
storage._storage.update_(_storage)
elif (
storage._storage.is_memmap()
and storage._storage.saved_path
and Path(storage._storage.saved_path).absolute()
== Path(path).absolute()
):
# If the storage is already where it should be, we don't need to load anything.
storage._storage.memmap_refresh_()

else:
storage._storage.copy_(_storage)
storage._len = _len
Expand Down
17 changes: 17 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,23 @@ class LazyMemmapStorage(LazyTensorStorage):
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
Defaults to ``1``.
.. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to the location where the
storage already lives on disk, which avoids long copies of data that is already stored there.
This only works if the default :class:`~torchrl.data.TensorStorageCheckpointer` checkpointer is used.
Example:
>>> from tensordict import TensorDict
>>> from torchrl.data import TensorStorage, LazyMemmapStorage, ReplayBuffer
>>> import tempfile
>>> from pathlib import Path
>>> import time
>>> td = TensorDict(a=0, b=1).expand(1000).clone()
>>> # We pass a path that is <main_ckpt_dir>/storage to LazyMemmapStorage
>>> rb_memmap = ReplayBuffer(storage=LazyMemmapStorage(10_000_000, scratch_dir="dump/storage"))
>>> rb_memmap.extend(td);
>>> # Checkpointing in `dump` is zero-copy, as the data is already in `dump/storage`
>>> rb_memmap.dumps(Path("./dump"))
Examples:
>>> data = TensorDict({
... "some data": torch.randn(10, 11),
Expand Down

0 comments on commit c1093b7

Please sign in to comment.