[Feature] Better repr of RBs (pytorch#1991)
vmoens authored Mar 4, 2024
1 parent 6fb16a2 commit 5ad2436
Showing 5 changed files with 178 additions and 10 deletions.
32 changes: 31 additions & 1 deletion test/test_rb.py
@@ -103,7 +103,12 @@


@pytest.mark.parametrize(
"sampler", [samplers.RandomSampler, samplers.PrioritizedSampler]
"sampler",
[
samplers.RandomSampler,
samplers.SamplerWithoutReplacement,
samplers.PrioritizedSampler,
],
)
@pytest.mark.parametrize(
"writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter]
@@ -184,6 +189,28 @@ def _get_data(self, datatype, size):
raise NotImplementedError(datatype)
return data

def test_rb_repr(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
"Distributed package support on Windows is a prototype feature and is subject to changes."
)
torch.manual_seed(0)
rb = self._get_rb(
rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
)
data = self._get_datum(datatype)
if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter:
with pytest.raises(
RuntimeError, match="expects data to be a tensor collection"
):
rb.add(data)
return
rb.add(data)
# we just check that str runs, not its value
assert str(rb)
rb.sample()
assert str(rb)

def test_add(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
@@ -2588,7 +2615,9 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)

rb = rbtype(storage=storage_cls(100, ndim=datadim), batch_size=4)
assert str(rb) # check str works
rb.extend(data)
assert str(rb)
assert len(rb) == 12
data = rb[:]
if datatype in ("tensordict", "tensorclass"):
@@ -2598,6 +2627,7 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
leaf.shape[:datadim].numel() == 12 for leaf in tree_flatten(data)[0]
)
s = rb.sample()
assert str(rb)
if datatype in ("tensordict", "tensorclass"):
assert (s.exclude("index") == 1).all()
assert s.numel() == 4
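For context, here is a minimal standalone sketch of the pattern the new test_rb_repr exercises: it only checks that str() runs both before and after sampling, never the exact text. The buffer configuration below is illustrative and not taken from the test fixtures.

import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage

# Illustrative only: mirror the new test's checks that str()/repr() run both
# before and after sampling, without asserting on the exact output.
rb = ReplayBuffer(storage=LazyTensorStorage(10), batch_size=2)
rb.add(torch.zeros(3))
assert str(rb)
rb.sample()
assert str(rb)
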
51 changes: 44 additions & 7 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -329,13 +329,21 @@ def __len__(self) -> int:
return len(self._storage)

def __repr__(self) -> str:
return (
f"{type(self).__name__}("
f"storage={self._storage}, "
f"sampler={self._sampler}, "
f"writer={self._writer}"
")"
)
from torchrl.envs.transforms import Compose

storage = textwrap.indent(f"storage={self._storage}", " " * 4)
writer = textwrap.indent(f"writer={self._writer}", " " * 4)
sampler = textwrap.indent(f"sampler={self._sampler}", " " * 4)
if self._transform is not None and not (
isinstance(self._transform, Compose) and not len(self._transform)
):
transform = textwrap.indent(f"transform={self._transform}", " " * 4)
transform = f"\n{transform}, "
else:
transform = ""
batch_size = textwrap.indent(f"batch_size={self._batch_size}", " " * 4)
collate_fn = textwrap.indent(f"collate_fn={self._collate_fn}", " " * 4)
return f"{self.__class__.__name__}(\n{storage}, \n{sampler}, \n{writer}, {transform}\n{batch_size}, \n{collate_fn})"

@pin_memory_output
def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any:
@@ -657,6 +665,33 @@ def __setstate__(self, state: Dict[str, Any]):
state["_futures_lock"] = _futures_lock
self.__dict__.update(state)

@property
def sampler(self):
"""The sampler of the replay buffer.
The sampler must be an instance of :class:`~torchrl.data.replay_buffers.Sampler`.
"""
return self._sampler

@property
def writer(self):
"""The writer of the replay buffer.
The writer must be an instance of :class:`~torchrl.data.replay_buffers.Writer`.
"""
return self._writer

@property
def storage(self):
"""The storage of the replay buffer.
The storage must be an instance of :class:`~torchrl.data.replay_buffers.Storage`.
"""
return self._storage


class PrioritizedReplayBuffer(ReplayBuffer):
"""Prioritized replay buffer.
@@ -1475,6 +1510,8 @@ class ReplayBufferEnsemble(ReplayBuffer):
"""

_collate_fn_val = None

def __init__(
self,
*rbs,
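As a hedged usage sketch (not part of the diff), the new multi-line __repr__ and the sampler/writer/storage properties added above can be exercised as follows; the storage size and batch size are arbitrary.

import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage

# Illustrative only: printing the buffer lays out storage, sampler, writer,
# batch_size and collate_fn on separate, indented lines.
rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=4)
rb.extend(torch.arange(10))
print(rb)
# The new properties expose the components directly:
print(rb.storage, rb.sampler, rb.writer, sep="\n")
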
57 changes: 56 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -89,6 +89,9 @@ def dumps(self, path):
def loads(self, path):
...

def __repr__(self):
return f"{self.__class__.__name__}()"


class RandomSampler(Sampler):
"""A uniformly random sampler for composable replay buffers.
@@ -247,6 +250,13 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.drop_last = state_dict["drop_last"]
self._ran_out = state_dict["_ran_out"]

def __repr__(self):
if self._sample_list is not None:
perc = len(self._sample_list) / self.len_storage * 100
else:
perc = 0.0
return f"{self.__class__.__name__}({perc: 4.4f}% sampled)"


class PrioritizedSampler(Sampler):
"""Prioritized sampler for replay buffer.
@@ -335,6 +345,9 @@ def __init__(
self.dtype = dtype
self._init()

def __repr__(self):
return f"{self.__class__.__name__}(alpha={self._alpha}, beta={self._beta}, eps={self._eps}, reduction={self.reduction})"

@property
def max_size(self):
return self._max_capacity
@@ -769,6 +782,16 @@ def __init__(
f"Got num_slices={num_slices} and slice_len={slice_len}."
)

def __repr__(self):
return (
f"{self.__class__.__name__}(num_slices={self.num_slices}, "
f"slice_len={self.slice_len}, "
f"end_key={self.end_key}, "
f"traj_key={self.traj_key}, "
f"truncated_key={self.truncated_key}, "
f"strict_length={self.strict_length})"
)

@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
if trajectory is not None:
@@ -1246,6 +1269,19 @@ def __init__(
)
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

def __repr__(self):
perc = len(self._sample_list) / self.len_storage * 100
return (
f"{self.__class__.__name__}("
f"num_slices={self.num_slices}, "
f"slice_len={self.slice_len}, "
f"end_key={self.end_key}, "
f"traj_key={self.traj_key}, "
f"truncated_key={self.truncated_key}, "
f"strict_length={self.strict_length},"
f"{perc}% sampled)"
)

def _empty(self):
self._cache = {}
SamplerWithoutReplacement._empty(self)
@@ -1422,6 +1458,25 @@ def __init__(
reduction=reduction,
)

def __repr__(self):
if self._sample_list is not None:
perc = len(self._sample_list) / self.len_storage * 100
else:
perc = 0.0
return (
f"{self.__class__.__name__}("
f"num_slices={self.num_slices}, "
f"slice_len={self.slice_len}, "
f"end_key={self.end_key}, "
f"traj_key={self.traj_key}, "
f"truncated_key={self.truncated_key}, "
f"strict_length={self.strict_length},"
f"alpha={self._alpha}, "
f"beta={self._beta}, "
f"eps={self._eps},"
f"{perc: 4.4f}% filled)"
)

def __getstate__(self):
state = SliceSampler.__getstate__(self)
state.update(PrioritizedSampler.__getstate__(self))
@@ -1778,4 +1833,4 @@ def __len__(self):

def __repr__(self):
samplers = textwrap.indent(f"samplers={self._samplers}", " " * 4)
return f"SamplerEnsemble(\n{samplers})"
return f"{self.__class__.__name__}(\n{samplers})"
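
For illustration only, the "% sampled" figure in SamplerWithoutReplacement.__repr__ is computed as len(self._sample_list) / self.len_storage * 100, as shown above. A hedged sketch of how it surfaces (the exact figure and formatting may vary):

import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# Illustrative only: after one sample() call the sampler repr reports a
# percentage derived from the entries remaining in its internal sample list.
rb = ReplayBuffer(
    storage=LazyTensorStorage(10),
    sampler=SamplerWithoutReplacement(),
    batch_size=2,
)
rb.extend(torch.arange(10))
rb.sample()
print(rb.sampler)  # e.g. SamplerWithoutReplacement( 80.0000% sampled)
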
37 changes: 36 additions & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -61,6 +61,10 @@ class Storage:
def __init__(self, max_size: int) -> None:
self.max_size = int(max_size)

@property
def _is_full(self):
return len(self) == self.max_size

@property
def _attached_entities(self):
# RBs that use a given instance of Storage should add
@@ -282,6 +286,9 @@ def __getstate__(self):
state = copy(self.__dict__)
return state

def __repr__(self):
return f"{self.__class__.__name__}(items=[{self._storage[0]}, ...])"


class TensorStorage(Storage):
"""A storage for tensors and tensordicts.
@@ -502,7 +509,11 @@ def _len_along_dim0(self):
# returns the length of the buffer along dim0
len_along_dim = len(self)
if self.ndim:
len_along_dim = len_along_dim // self._total_shape[1:].numel()
_total_shape = self._total_shape
if _total_shape is not None:
len_along_dim = len_along_dim // _total_shape[1:].numel()
else:
return None
return len_along_dim

def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
@@ -548,6 +559,8 @@ def _rand_given_ndim(self, batch_size):
def flatten(self):
if self.ndim == 1:
return self
if not self.initialized:
raise RuntimeError("Cannot flatten a non-initialized storage.")
if is_tensor_collection(self._storage):
if self._is_full:
return TensorStorage(self._storage.flatten(0, self.ndim - 1))
@@ -766,6 +779,8 @@ def set( # noqa: F811
def get(self, index: Union[int, Sequence[int], slice]) -> Any:
_storage = self._storage
is_tc = is_tensor_collection(_storage)
if not self.initialized:
raise RuntimeError("Cannot get elements out of a non-initialized storage.")
if not self._is_full:
if is_tc:
storage = self._storage[: self._len_along_dim0]
@@ -795,6 +810,26 @@ def _init(self):
f"{type(self)} must be initialized during construction."
)

def __repr__(self):
if not self.initialized:
storage_str = textwrap.indent("data=<empty>", 4 * " ")
elif is_tensor_collection(self._storage):
storage_str = textwrap.indent(f"data={self[:]}", 4 * " ")
else:

def repr_item(x):
if isinstance(x, torch.Tensor):
return f"{x.__class__.__name__}(shape={x.shape}, dtype={x.dtype}, device={x.device})"
return x.__class__.__name__

storage_str = textwrap.indent(
f"data={tree_map(repr_item, self[:])}", 4 * " "
)
shape_str = textwrap.indent(f"shape={self.shape}", 4 * " ")
len_str = textwrap.indent(f"len={len(self)}", 4 * " ")
maxsize_str = textwrap.indent(f"max_size={self.max_size}", 4 * " ")
return f"{self.__class__.__name__}(\n{storage_str}, \n{shape_str}, \n{len_str}, \n{maxsize_str})"


class LazyTensorStorage(TensorStorage):
"""A pre-allocated tensor storage for tensors and tensordicts.
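A standalone sketch of the per-leaf summary that the repr_item helper in TensorStorage.__repr__ produces for plain-tensor storages (illustrative, not part of the diff):

import torch

# Mirrors the repr_item helper added above: tensors are summarized by
# shape/dtype/device, anything else falls back to its class name.
def repr_item(x):
    if isinstance(x, torch.Tensor):
        return f"{x.__class__.__name__}(shape={x.shape}, dtype={x.dtype}, device={x.device})"
    return x.__class__.__name__

print(repr_item(torch.zeros(3, 4)))  # Tensor(shape=torch.Size([3, 4]), dtype=torch.float32, device=cpu)
print(repr_item("not a tensor"))     # str
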
11 changes: 11 additions & 0 deletions torchrl/data/replay_buffers/writers.py
@@ -37,6 +37,8 @@ def tree_leaves(data): # noqa: D103
class Writer(ABC):
"""A ReplayBuffer base Writer class."""

_storage: Storage

def __init__(self) -> None:
self._storage = None

@@ -92,6 +94,9 @@ def _replicate_index(self, index):
1,
)

def __repr__(self):
return f"{self.__class__.__name__}()"


class ImmutableDatasetWriter(Writer):
"""A blocking writer for immutable datasets."""
@@ -210,6 +215,9 @@ def __setstate__(self, state):
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)

def __repr__(self):
return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full})"


class TensorDictRoundRobinWriter(RoundRobinWriter):
"""A RoundRobin Writer class for composable, tensordict-based replay buffers."""
@@ -521,6 +529,9 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
raise NotImplementedError

def __repr__(self):
return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full}, rank_key={self._rank_key}, reduction={self._reduction})"


class WriterEnsemble(Writer):
"""An ensemble of writers.
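For illustration only (not in the diff), the writer repr surfaces the cursor position and whether the underlying storage has reached max_size; the sketch below assumes the default RoundRobinWriter.

import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage

# Illustrative only: filling a size-4 storage wraps the round-robin cursor
# back to 0 and marks the storage as full.
rb = ReplayBuffer(storage=LazyTensorStorage(4), batch_size=2)
rb.extend(torch.arange(4))
print(rb.writer)  # e.g. RoundRobinWriter(cursor=0, full_storage=True); exact text may vary
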
