[BugFix] Fix replay buffer extension with lists (pytorch#1937)
vmoens authored Feb 20, 2024
1 parent eacad37 commit c45ee1f
Showing 3 changed files with 65 additions and 30 deletions.
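
For context, here is a minimal sketch of the list-vs-tuple semantics this commit fixes. It assumes torchrl's ReplayBuffer and LazyTensorStorage API with pytree data, as exercised by the new test below; the snippet itself is an illustration and not part of the diff:

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    # A list passed to extend() means per-item addition: every element must
    # share the same tree structure, and matching leaves are stacked into a batch.
    rb = ReplayBuffer(storage=LazyTensorStorage(max_size=100))
    items = [
        (torch.full((3,), float(i)), {"a": torch.full((3,), float(i))})
        for i in range(10)
    ]
    rb.extend(items)

    # A tuple (or any other non-list pytree) is treated as an already-batched,
    # "compact" structure: each leaf carries the batch dimension itself.
    rb2 = ReplayBuffer(storage=LazyTensorStorage(max_size=100))
    rb2.extend((torch.zeros(10, 3), {"a": torch.ones(10, 3)}))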
60 changes: 32 additions & 28 deletions test/test_rb.py
@@ -156,44 +156,20 @@ def _get_datum(self, datatype):

     def _get_data(self, datatype, size):
         if datatype is None:
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
         elif datatype == "tensor":
-            data = torch.randint(
-                100,
-                (
-                    size,
-                    1,
-                ),
-            )
+            data = torch.randint(100, (size, 1))
         elif datatype == "tensordict":
             data = TensorDict(
                 {
-                    "a": torch.randint(
-                        100,
-                        (
-                            size,
-                            1,
-                        ),
-                    ),
+                    "a": torch.randint(100, (size, 1)),
                     "next": {"reward": torch.randn(size, 1)},
                 },
                 [size],
             )
         elif datatype == "pytree":
             data = {
-                "a": torch.randint(
-                    100,
-                    (
-                        size,
-                        1,
-                    ),
-                ),
+                "a": torch.randint(100, (size, 1)),
                 "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]},
                 30: torch.zeros(size, 2),
             }
@@ -838,6 +814,34 @@ def test_set_tensorclass(self, max_size, shape, storage):
         tc_sample = mystorage.get(idx)
         assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]])
 
+    def test_extend_list_pytree(self, max_size, shape, storage):
+        memory = ReplayBuffer(
+            storage=storage(max_size=max_size),
+            sampler=SamplerWithoutReplacement(),
+        )
+        data = [
+            (
+                torch.full(shape, i),
+                {"a": torch.full(shape, i), "b": (torch.full(shape, i))},
+                [torch.full(shape, i)],
+            )
+            for i in range(10)
+        ]
+        memory.extend(data)
+        sample = memory.sample(10)
+        for leaf in torch.utils._pytree.tree_leaves(sample):
+            assert (leaf.unique(sorted=True) == torch.arange(10)).all()
+        memory = ReplayBuffer(
+            storage=storage(max_size=max_size),
+            sampler=SamplerWithoutReplacement(),
+        )
+        t1x4 = torch.Tensor([0.1, 0.2, 0.3, 0.4])
+        t1x1 = torch.Tensor([0.01])
+        with pytest.raises(
+            RuntimeError, match="Stacking the elements of the list resulted in an error"
+        ):
+            memory.extend([t1x4, t1x1, t1x4 + 0.4, t1x1 + 0.01])
+
 
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
22 changes: 22 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -542,6 +542,21 @@ def set(
         else:
             self._len = max(self._len, max(cursor) + 1)
 
+        if isinstance(data, list):
+            # flip list
+            try:
+                data = _flip_list(data)
+            except Exception:
+                raise RuntimeError(
+                    "Stacking the elements of the list resulted in "
+                    "an error. "
+                    f"Storages of type {type(self)} expect all elements of the list "
+                    f"to have the same tree structure. If the list is compact (each "
+                    f"leaf is itself a batch with the appropriate number of elements) "
+                    f"consider using a tuple instead, as lists are used within `extend` "
+                    f"for per-item addition."
+                )
+
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
@@ -1319,3 +1334,10 @@ def save_tensor(tensor_path: str, tensor: torch.Tensor):
             out.append(save_tensor(tensor_path, tensor))
 
     return tree_unflatten(out, data_specs)
+
+
+def _flip_list(data):
+    flat_data, flat_specs = zip(*[tree_flatten(item) for item in data])
+    flat_data = zip(*flat_data)
+    stacks = [torch.stack(item) for item in flat_data]
+    return tree_unflatten(stacks, flat_specs[0])
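
For illustration, a small standalone sketch of what the new _flip_list helper does: it turns a per-item list of pytrees with identical structure into a single pytree whose leaves are stacked along a new batch dimension. The example data are made up; the helper body is copied from the diff above:

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    def _flip_list(data):
        # one (leaves, spec) pair per list element; all specs must be identical
        flat_data, flat_specs = zip(*[tree_flatten(item) for item in data])
        # group the i-th leaf of every element together
        flat_data = zip(*flat_data)
        # stack each group into a single batched leaf
        stacks = [torch.stack(item) for item in flat_data]
        return tree_unflatten(stacks, flat_specs[0])

    items = [{"a": torch.full((2,), float(i)), "b": (torch.ones(3) * i,)} for i in range(4)]
    batched = _flip_list(items)
    print(batched["a"].shape)     # torch.Size([4, 2])
    print(batched["b"][0].shape)  # torch.Size([4, 3])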
13 changes: 11 additions & 2 deletions torchrl/data/replay_buffers/writers.py
@@ -18,7 +18,16 @@
 from tensordict import is_tensor_collection, MemoryMappedTensor
 from tensordict.utils import _STRDTYPE2DTYPE
 from torch import multiprocessing as mp
-from torch.utils._pytree import tree_flatten
+
+try:
+    from torch.utils._pytree import tree_leaves
+except ImportError:
+    from torch.utils._pytree import tree_flatten
+
+    def tree_leaves(data):  # noqa: D103
+        tree_flat, _ = tree_flatten(data)
+        return tree_flat
+
 
 from torchrl.data.replay_buffers.storages import Storage
 from torchrl.data.replay_buffers.utils import _reduce
@@ -125,7 +134,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
         elif isinstance(data, list):
             batch_size = len(data)
         else:
-            batch_size = len(tree_flatten(data)[0][0])
+            batch_size = len(tree_leaves(data)[0])
         if batch_size == 0:
             raise RuntimeError("Expected at least one element in extend.")
         device = data.device if hasattr(data, "device") else None
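
As a sketch of the updated batch-size logic (the example data are assumptions, not from the commit): extend() now counts the length of the first leaf of the pytree via tree_leaves, using the same fallback shim as above for torch versions that lack it:

    import torch

    try:
        from torch.utils._pytree import tree_leaves
    except ImportError:  # older torch: tree_leaves is missing, emulate it
        from torch.utils._pytree import tree_flatten

        def tree_leaves(data):
            return tree_flatten(data)[0]

    batch = {"obs": torch.zeros(5, 3), "action": {"value": torch.zeros(5, 2)}}
    batch_size = len(tree_leaves(batch)[0])  # length of the first leaf -> 5
    assert batch_size == 5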