Skip to content

Commit

Permalink
[BugFix] Extend with a list of tensordicts (pytorch#2032)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 21, 2024
1 parent 660d827 commit a69c667
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,12 @@ class TC:
if data == "tc":
assert storage._storage.text == storage_recover._storage.text

def test_add_list_of_tds(self):
rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.extend([TensorDict({"a": torch.randn(2, 3)}, [2])])
assert len(rb) == 1
assert rb[:].shape == torch.Size([1, 2])


@pytest.mark.parametrize("max_size", [1000])
@pytest.mark.parametrize("shape", [[3, 4]])
Expand Down
2 changes: 2 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,8 @@ def save_tensor(tensor_path: str, tensor: torch.Tensor):


def _flip_list(data):
if all(is_tensor_collection(_data) for _data in data):
return torch.stack(data)
flat_data, flat_specs = zip(*[tree_flatten(item) for item in data])
flat_data = zip(*flat_data)
stacks = [torch.stack(item) for item in flat_data]
Expand Down

0 comments on commit a69c667

Please sign in to comment.