[Performance, Refactor, BugFix] Faster loading of uninitialized storages (pytorch#2221)
vmoens authored Jun 11, 2024
1 parent 166467a commit 3787a9e
Showing 3 changed files with 52 additions and 15 deletions.
36 changes: 29 additions & 7 deletions test/test_rb.py
@@ -93,6 +93,7 @@
RewardClipping,
RewardScaling,
SqueezeTransform,
StepCounter,
ToTensorImage,
UnsqueezeTransform,
VecNorm,
@@ -3114,16 +3115,20 @@ class TestCheckpointers:
"checkpointer",
[FlatStorageCheckpointer, H5StorageCheckpointer, NestedStorageCheckpointer],
)
def test_simple_env(self, storage_type, checkpointer, tmpdir):
@pytest.mark.parametrize("frames_per_batch", [22, 122])
def test_simple_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
env = GymEnv(CARTPOLE_VERSIONED(), device=None)
env.set_seed(0)
torch.manual_seed(0)
collector = SyncDataCollector(
env, policy=env.rand_step, total_frames=200, frames_per_batch=22
env,
policy=env.rand_step,
total_frames=200,
frames_per_batch=frames_per_batch,
)
rb = ReplayBuffer(storage=storage_type(100))
rb_test = ReplayBuffer(storage=storage_type(100))
if torch.__version__ < "2.4.0" and checkpointer in (
if torch.__version__ < "2.4.0.dev" and checkpointer in (
H5StorageCheckpointer,
NestedStorageCheckpointer,
):
@@ -3137,22 +3142,32 @@ def test_simple_env(self, storage_type, checkpointer, tmpdir):
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])
assert rb._writer._cursor == rb_test._writer._cursor

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("frames_per_batch", [22, 122])
@pytest.mark.parametrize(
"checkpointer",
[FlatStorageCheckpointer, NestedStorageCheckpointer, H5StorageCheckpointer],
)
def test_multi_env(self, storage_type, checkpointer, tmpdir):
env = SerialEnv(3, lambda: GymEnv(CARTPOLE_VERSIONED(), device=None))
def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
env = SerialEnv(
3,
lambda: GymEnv(CARTPOLE_VERSIONED(), device=None).append_transform(
StepCounter()
),
)
env.set_seed(0)
torch.manual_seed(0)
collector = SyncDataCollector(
env, policy=env.rand_step, total_frames=200, frames_per_batch=22
env,
policy=env.rand_step,
total_frames=200,
frames_per_batch=frames_per_batch,
)
rb = ReplayBuffer(storage=storage_type(100, ndim=2))
rb_test = ReplayBuffer(storage=storage_type(100, ndim=2))
if torch.__version__ < "2.4.0" and checkpointer in (
if torch.__version__ < "2.4.0.dev" and checkpointer in (
H5StorageCheckpointer,
NestedStorageCheckpointer,
):
@@ -3164,10 +3179,17 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir):
for data in collector:
rb.extend(data)
assert rb._storage.max_size == 102
if frames_per_batch > 100:
assert rb._storage._is_full
assert len(rb) == 102
# Checks that when writing to the buffer with a batch greater than the total
# size, we get the last step written properly.
assert (rb[:]["next", "step_count"][:, -1] != 0).any()
rb.dumps(tmpdir)
rb.dumps(tmpdir)
rb_test.loads(tmpdir)
assert_allclose_td(rb_test[:], rb[:])
assert rb._writer._cursor == rb_test._writer._cursor


if __name__ == "__main__":
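Note on the new test parameter: with frames_per_batch=122 the collector hands the buffer a batch larger than its capacity, so the writer must wrap around and the checkpoint must still restore the last written step. A minimal sketch of that wrap-around behaviour, assuming torchrl and tensordict are installed (the "step" key and sizes are illustrative, not taken from the test):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(100))
    # A single batch larger than the storage: the writer wraps past index 99.
    batch = TensorDict({"step": torch.arange(122)}, batch_size=[122])
    rb.extend(batch)
    print(len(rb))               # expected: 100 (the storage holds max_size items)
    print(rb._storage._is_full)  # expected: True, mirroring the test's assertion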
13 changes: 7 additions & 6 deletions torchrl/data/replay_buffers/storages.py
@@ -432,6 +432,7 @@ def _total_shape(self):
leaf = next(tree_iter(self._storage))
_total_shape = leaf.shape[: self.ndim]
self.__dict__["_total_shape_value"] = _total_shape
self._len = torch.Size([self._len_along_dim0, *_total_shape[1:]]).numel()
return _total_shape

@property
@@ -443,18 +444,18 @@ def _is_full(self):
def _len_along_dim0(self):
# returns the length of the buffer along dim0
len_along_dim = len(self)
if self.ndim:
if self.ndim > 1:
_total_shape = self._total_shape
if _total_shape is not None:
len_along_dim = len_along_dim // _total_shape[1:].numel()
len_along_dim = -(len_along_dim // -_total_shape[1:].numel())
else:
return None
return len_along_dim

def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
# returns the max_size of the buffer along dim0
max_size = self.max_size
if self.ndim:
if self.ndim > 1:
shape = self.shape
if shape is None:
if single_data is not None:
@@ -471,14 +472,14 @@ def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
break
if batched_data is not None:
datashape = datashape[1:]
max_size = max_size // datashape.numel()
max_size = -(max_size // -datashape.numel())
else:
max_size = max_size // self._total_shape[1:].numel()
max_size = -(max_size // -self._total_shape[1:].numel())
return max_size

@property
def shape(self):
# Shape, turncated where needed to accomodate for the length of the storage
# Shape, truncated where needed to accommodate for the length of the storage
if self._is_full:
return self._total_shape
_total_shape = self._total_shape
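The storages.py changes swap plain floor division for the ceil-division idiom -(a // -b) when deriving the per-dim-0 length and capacity, and only perform the division when ndim > 1. That rounding change is why test_multi_env now expects a total capacity of 102 for max_size=100 with 3 envs: the row count rounds up to 34 instead of down to 33. A plain-Python illustration, no torchrl required:

    def ceil_div(a, b):
        # -(a // -b) floor-divides the negated numerator, then negates back,
        # which gives the ceiling of a / b for positive integers.
        return -(a // -b)

    max_size, n_envs = 100, 3
    print(max_size // n_envs)          # 33 -> 33 * 3 = 99 slots (old rounding)
    print(ceil_div(max_size, n_envs))  # 34 -> 34 * 3 = 102 slots (new rounding)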
18 changes: 16 additions & 2 deletions torchrl/data/replay_buffers/utils.py
@@ -460,16 +460,30 @@ def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
if _storage_shape is not None and len(_storage_shape) > 1:
# iterate over data and allocate
if out is None:
# out = TensorDict(batch_size=_storage_shape)
# for i in range(out.ndim):
# if i >= 2:
# # FLattening the lazy stack will make the data unavailable - we need to find a way to make this
# # possible.
# raise RuntimeError(
# "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
# "Please file an issue on GitHub to ask for this feature!"
# )
# out = LazyStackedTensorDict(*out.unbind(i), stack_dim=i)
out = TensorDict(batch_size=_storage_shape)
for i in range(out.ndim):
for i in range(1, out.ndim):
if i >= 2:
# FLattening the lazy stack will make the data unavailable - we need to find a way to make this
# possible.
raise RuntimeError(
"Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
"Please file an issue on GitHub to ask for this feature!"
)
out = LazyStackedTensorDict(*out.unbind(i), stack_dim=i)
out_list = [
out._get_sub_tensordict((slice(None),) * i + (j,))
for j in range(out.shape[i])
]
out = LazyStackedTensorDict(*out_list, stack_dim=i)

# Create a function that reads slices of the input data
with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext(
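The utils.py change keeps the old unbind-based construction as a commented-out reference and instead builds the lazy stack from _get_sub_tensordict views, starting at dim 1, so wrapping an uninitialized (empty) TensorDict no longer materialises every sub-tensordict. A minimal sketch of the pattern under the same assumptions as the patch (tensordict installed; _get_sub_tensordict is the private helper the patch itself calls; the 100 x 3 shape is illustrative):

    from tensordict import LazyStackedTensorDict, TensorDict

    out = TensorDict(batch_size=[100, 3])  # empty, uninitialized storage shape
    i = 1
    # Views over dim i instead of out.unbind(i): nothing is copied or allocated.
    out_list = [
        out._get_sub_tensordict((slice(None),) * i + (j,))
        for j in range(out.shape[i])
    ]
    out = LazyStackedTensorDict(*out_list, stack_dim=i)
    print(out.batch_size)  # expected: torch.Size([100, 3])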
