[Feature] Replay buffer checkpointers #2137

Merged
merged 14 commits on May 16, 2024
t
vmoens committed May 7, 2024
commit 4e86a09791cfa00dfca173eeac8c41cd355d1ea4
10 changes: 6 additions & 4 deletions test/test_rb.py
@@ -93,7 +93,7 @@
    UnsqueezeTransform,
    VecNorm,
)

from torchrl.data import FlatStorageCheckpointer, NestedStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, TensorStorageCheckpointer, ListStorageCheckpointer
OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
@@ -2904,7 +2904,7 @@ def test_done_slicesampler(self, strict_length):


@pytest.mark.skipif(not _has_gym, reason="gym required")
class TestSaveHooks:
class TestCheckpointers:
    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
    def test_simple_env(self, storage_type, tmpdir):
        env = GymEnv(CARTPOLE_VERSIONED())
@@ -2915,8 +2915,10 @@ def test_simple_env(self, storage_type, tmpdir):
        )
        rb = ReplayBuffer(storage=storage_type(1000))
        rb_test = ReplayBuffer(storage=storage_type(1000))
        rb.register_save_hook(TED2Flat())
        rb_test.register_load_hook(Flat2TED())
        # rb.storage.checkpointer = FlatStorageCheckpointer()
        # rb_test.storage.checkpointer = FlatStorageCheckpointer()
        rb.storage.checkpointer = NestedStorageCheckpointer()
        rb_test.storage.checkpointer = NestedStorageCheckpointer()
        for i, data in enumerate(collector):
            rb.extend(data)
            if i == 0:
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -62,3 +62,4 @@
    UnboundedDiscreteTensorSpec,
)
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
from .replay_buffers import FlatStorageCheckpointer, NestedStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, TensorStorageCheckpointer, ListStorageCheckpointer
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -38,3 +38,4 @@
    Writer,
    WriterEnsemble,
)
from .checkpointers import FlatStorageCheckpointer, NestedStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, TensorStorageCheckpointer, ListStorageCheckpointer
270 changes: 270 additions & 0 deletions torchrl/data/replay_buffers/checkpointers.py
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
import json
from pathlib import Path

import torch
from tensordict import is_tensor_collection, TensorDict, PersistentTensorDict
from tensordict.memmap import MemoryMappedTensor
from tensordict.utils import _STRDTYPE2DTYPE
from torchrl._utils import implement_for
from torchrl.data.replay_buffers.utils import (
    _get_paths,
    _init_pytree,
    _init_pytree_common,
    _path2str,
    _save_pytree,
    _save_pytree_common,
    Flat2TED,
    H5Combine,
    H5Split,
    Nested2TED,
    TED2Flat,
    TED2Nested,
)


class StorageCheckpointerBase:
    """Public base class for storage checkpointers.

    Each storage checkpointer must implement a `dumps` and a `loads` method that take as input a
    storage and a path.

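    A minimal sketch of a custom checkpointer is shown below; the class name and the use of
    ``torch.save`` / ``torch.load`` are illustrative assumptions, not an API of the library.

    Examples:
        >>> class PickleStorageCheckpointer(StorageCheckpointerBase):
        ...     def dumps(self, storage, path):
        ...         # sketch: serialize the underlying data container in a single file
        ...         torch.save(storage._storage, Path(path) / "storage.pt")
        ...     def loads(self, storage, path):
        ...         # sketch: restore in-place into an already-initialized storage
        ...         storage._storage.copy_(torch.load(Path(path) / "storage.pt"))
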
"""

    @abc.abstractmethod
    def dumps(self, storage, path):
        ...

    @abc.abstractmethod
    def loads(self, storage, path):
        ...


class ListStorageCheckpointer(StorageCheckpointerBase):
    """A storage checkpointer for ListStorage.

    Currently not implemented.

    """

    @staticmethod
    def dumps(storage, path):
        raise NotImplementedError(
            "ListStorage doesn't support serialization via `dumps` - `loads` API."
        )

    @staticmethod
    def loads(storage, path):
        raise NotImplementedError(
            "ListStorage doesn't support serialization via `dumps` - `loads` API."
        )


class TensorStorageCheckpointer(StorageCheckpointerBase):
    """A storage checkpointer for TensorStorages.

    This class supports TensorDict-based storages as well as pytrees.

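    A minimal round-trip sketch is given below; it assumes ``storage`` is an already-filled
    ``TensorStorage`` (e.g. a ``LazyTensorStorage``), ``storage_copy`` is a second storage with
    matching specs, and the target directory is a placeholder.

    Examples:
        >>> checkpointer = TensorStorageCheckpointer()
        >>> checkpointer.dumps(storage, "/tmp/checkpoint")   # writes memmap files and metadata
        >>> checkpointer.loads(storage_copy, "/tmp/checkpoint")
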
"""

    _save_hooks = []
    _load_hooks = []

    def dumps(self, storage, path):
        path = Path(path)
        path.mkdir(exist_ok=True)

        if not storage.initialized:
            raise RuntimeError("Cannot save a non-initialized storage.")
        metadata = {}
        _storage = storage._storage
        for hook in self._save_hooks:
            _storage = hook(_storage)
        if is_tensor_collection(_storage):
            # try to load the path and overwrite.
            _storage.memmap(
                path, copy_existing=True, num_threads=torch.get_num_threads()
            )
            is_pytree = False
        else:
            _save_pytree(_storage, metadata, path)
            is_pytree = True

        with open(path / "storage_metadata.json", "w") as file:
            json.dump(
                {
                    "metadata": metadata,
                    "is_pytree": is_pytree,
                    "len": storage._len,
                },
                file,
            )

    def loads(self, storage, path):
        with open(path / "storage_metadata.json", "r") as file:
            metadata = json.load(file)
        is_pytree = metadata["is_pytree"]
        _len = metadata["len"]
        if is_pytree:
            if self._load_hooks:
                raise RuntimeError(
                    "Loading hooks are not compatible with PyTree storages."
                )
            path = Path(path)
            for local_path, md in metadata["metadata"].items():
                # load tensor
                local_path_dot = local_path.replace(".", "/")
                total_tensor_path = path / (local_path_dot + ".memmap")
                shape = torch.Size(md["shape"])
                dtype = _STRDTYPE2DTYPE[md["dtype"]]
                tensor = MemoryMappedTensor.from_filename(
                    filename=total_tensor_path, shape=shape, dtype=dtype
                )
                # split path
                local_path = local_path.split(".")
                # replace potential dots
                local_path = [_path.replace("_<dot>_", ".") for _path in local_path]
                if storage.initialized:
                    # copy in-place
                    _storage_tensor = storage._storage
                    # in this case there is a single tensor, so we skip
                    if local_path != ["_-single-tensor-_"]:
                        for _path in local_path:
                            if _path.isdigit():
                                _path_attempt = int(_path)
                                try:
                                    _storage_tensor = _storage_tensor[_path_attempt]
                                    continue
                                except IndexError:
                                    pass
                            _storage_tensor = _storage_tensor[_path]
                    _storage_tensor.copy_(tensor)
                else:
                    raise RuntimeError(
                        "Cannot fill a non-initialized pytree-based TensorStorage."
                    )
        else:
            _storage = TensorDict.load_memmap(path)
            for hook in self._load_hooks:
                _storage = hook(_storage)
            if not storage.initialized:
                # this should not be reached if is_pytree=True
                storage._init(_storage[0])
                storage._storage.update_(_storage)
            else:
                storage._storage.copy_(_storage)
        storage._len = _len


class FlatStorageCheckpointer(TensorStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance).

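    A usage sketch mirroring the test in this commit is shown below; it assumes the buffer holds
    TED-formatted data, that ``ReplayBuffer.dumps`` / ``ReplayBuffer.loads`` route through the
    storage checkpointer, and that the paths are placeholders.

    Examples:
        >>> rb = ReplayBuffer(storage=LazyMemmapStorage(1000))
        >>> rb.storage.checkpointer = FlatStorageCheckpointer()
        >>> # ... rb.extend(collected_data) ...
        >>> rb.dumps("/path/to/dir")
        >>> rb_restore = ReplayBuffer(storage=LazyMemmapStorage(1000))
        >>> rb_restore.storage.checkpointer = FlatStorageCheckpointer()
        >>> rb_restore.loads("/path/to/dir")
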
"""

    def __init__(self):
        self._save_hooks = [TED2Flat()]
        self._load_hooks = [Flat2TED()]


class NestedStorageCheckpointer(TensorStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format and using memory-mapped nested tensors.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance).

    """

    def __init__(self):
        self._save_hooks = [TED2Nested()]
        self._load_hooks = [Nested2TED()]

class H5StorageCheckpointer(TensorStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format and using the H5 format to save the data.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance).

    Args:
        checkpoint_file: the filename of the H5 file written inside the checkpoint directory.
            Defaults to ``"checkpoint.h5"``.
        **kwargs: kwargs to be passed to :meth:`h5py.File.create_dataset`.

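    A construction sketch is given below; ``compression="gzip"`` stands for any keyword argument
    accepted by :meth:`h5py.File.create_dataset`, and the paths are placeholders.

    Examples:
        >>> checkpointer = H5StorageCheckpointer(checkpoint_file="buffer.h5", compression="gzip")
        >>> checkpointer.dumps(storage, "/path/to/dir")   # writes /path/to/dir/buffer.h5
        >>> checkpointer.loads(storage_to_fill, "/path/to/dir")
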
"""

    def __init__(self, checkpoint_file: str = "checkpoint.h5", **kwargs):
        self._save_hooks = [TED2Nested(), H5Split()]
        self._load_hooks = [H5Combine(), Nested2TED()]
        self.kwargs = kwargs
        self.checkpoint_file = checkpoint_file

    def dumps(self, storage, path):
        path = Path(path)
        path.mkdir(exist_ok=True)
        path = path / self.checkpoint_file

        if not storage.initialized:
            raise RuntimeError("Cannot save a non-initialized storage.")
        metadata = {}
        _storage = storage._storage
        length = len(storage)
        for hook in self._save_hooks:
            _storage = hook(_storage)
        if is_tensor_collection(_storage):
            # write the storage content as an H5 file at the checkpoint path
            data = PersistentTensorDict.from_dict(_storage, path, **self.kwargs)
            data["_len"] = length
        else:
            raise ValueError(
                "H5StorageCheckpointer only supports TensorDict-based storages."
            )

    def loads(self, storage, path):
        path = Path(path)
        path = path / self.checkpoint_file
        data = PersistentTensorDict.from_h5(path)
        _len = data["_len"]
        for hook in self._load_hooks:
            data = hook(data)
        if not storage.initialized:
            # initialize the storage from the first element before filling it
            storage._init(data[0])
            storage._storage.update_(data)
        else:
            storage._storage.copy_(data)
        # TODO
        storage._len = _len


class StorageEnsembleCheckpointer(StorageCheckpointerBase):
    """Checkpointer for ensemble storages."""

    @staticmethod
    def dumps(storage, path: Path):
        path = Path(path).absolute()
        storages = storage._storages
        for i, _storage in enumerate(storages):
            _storage.dumps(path / str(i))
        if storage._transforms is not None:
            for i, transform in enumerate(storage._transforms):
                torch.save(transform.state_dict(), path / f"{i}_transform.pt")

    @staticmethod
    def loads(storage, path: Path):
        path = Path(path).absolute()
        for i, _storage in enumerate(storage._storages):
            _storage.loads(path / str(i))
        if storage._transforms is not None:
            for i, transform in enumerate(storage._transforms):
                transform.load_state_dict(torch.load(path / f"{i}_transform.pt"))
10 changes: 6 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -203,6 +203,7 @@ def __init__(
transform: "Transform" | None = None, # noqa-F821
batch_size: int | None = None,
dim_extend: int | None = None,
checkpointer: StorageCheckpointerBase | None = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage.attach(self)
@@ -260,6 +261,7 @@ def __init__(
        if dim_extend is not None and dim_extend < 0:
            raise ValueError("dim_extend must be a positive value.")
        self.dim_extend = dim_extend
        self._storage.checkpointer = checkpointer

    @property
    def dim_extend(self):
@@ -1021,12 +1023,12 @@ class TensorDictReplayBuffer(ReplayBuffer):

"""

    def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
        writer = kw.get("writer", None)
    def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None:
        writer = kwargs.get("writer", None)
        if writer is None:
            kw["writer"] = TensorDictRoundRobinWriter()
            kwargs["writer"] = TensorDictRoundRobinWriter()

        super().__init__(**kw)
        super().__init__(**kwargs)
        self.priority_key = priority_key

    def _get_priority_item(self, tensordict: TensorDictBase) -> float: