[Feature] Default collate_fn (pytorch#688)
* init

* amend
vmoens authored Nov 19, 2022
1 parent f5d98af commit 40c04ef
Showing 5 changed files with 68 additions and 54 deletions.
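
In short: when no collate_fn is passed, a replay buffer now derives one from its storage via the new _get_default_collate helper: list-based storages (ListStorage) stack the sampled items, while contiguous storages (LazyTensorStorage, LazyMemmapStorage) return the sampled batch as a fresh tensordict or clone. A minimal sketch of the usage this enables, assuming the prototype buffer API shown in the diff below (import paths may vary across torchrl versions):

    import torch
    from torchrl.data.replay_buffers import rb_prototype
    from torchrl.data.replay_buffers.samplers import RandomSampler
    from torchrl.data.replay_buffers.storages import LazyMemmapStorage
    from torchrl.data.replay_buffers.writers import RoundRobinWriter

    # No collate_fn argument: a default is picked based on the storage type.
    rb = rb_prototype.ReplayBuffer(
        storage=LazyMemmapStorage(10),
        sampler=RandomSampler(),
        writer=RoundRobinWriter(),
    )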
44 changes: 18 additions & 26 deletions test/test_rb.py
@@ -30,12 +30,12 @@
from torchrl.data.replay_buffers.writers import RoundRobinWriter


collate_fn_dict = {
ListStorage: lambda x: torch.stack(x, 0),
LazyTensorStorage: lambda x: x,
LazyMemmapStorage: lambda x: x,
None: lambda x: torch.stack(x, 0),
}
# collate_fn_dict = {
# ListStorage: lambda x: torch.stack(x, 0),
# LazyTensorStorage: lambda x: x,
# LazyMemmapStorage: lambda x: x,
# None: lambda x: torch.stack(x, 0),
# }


@pytest.mark.parametrize(
@@ -54,7 +54,6 @@
@pytest.mark.parametrize("size", [3, 100])
class TestPrototypeBuffers:
def _get_rb(self, rb_type, size, sampler, writer, storage):
collate_fn = collate_fn_dict[storage]

if storage is not None:
storage = storage(size)
@@ -65,9 +64,7 @@ def _get_rb(self, rb_type, size, sampler, writer, storage):

sampler = sampler(**sampler_args)
writer = writer()
rb = rb_type(
collate_fn=collate_fn, storage=storage, sampler=sampler, writer=writer
)
rb = rb_type(storage=storage, sampler=sampler, writer=writer)
return rb

def _get_datum(self, rb_type):
@@ -192,7 +189,6 @@ def test_prototype_prb(priority_key, contiguous, device):
np.random.seed(0)
rb = rb_prototype.TensorDictReplayBuffer(
sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9),
collate_fn=None if contiguous else lambda x: torch.stack(x, 0),
priority_key=priority_key,
)
td1 = TensorDict(
@@ -271,7 +267,6 @@ def test_rb_prototype_trajectories(stack):
alpha=0.7,
beta=0.9,
),
collate_fn=lambda x: torch.stack(x, 0),
priority_key="td_error",
)
rb.extend(traj_td)
@@ -315,7 +310,6 @@ class TestBuffers:
_default_params_td_prb = {"alpha": 0.8, "beta": 0.9}

def _get_rb(self, rbtype, size, storage, prefetch):
collate_fn = collate_fn_dict[storage]
if storage is not None:
storage = storage(size)
if rbtype is ReplayBuffer:
@@ -328,13 +322,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
params = self._default_params_td_prb
else:
raise NotImplementedError(rbtype)
rb = rbtype(
size=size,
storage=storage,
prefetch=prefetch,
collate_fn=collate_fn,
**params
)
rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params)
return rb

def _get_datum(self, rbtype):
@@ -460,7 +448,6 @@ def test_prb(priority_key, contiguous, device):
5,
alpha=0.7,
beta=0.9,
collate_fn=None if contiguous else lambda x: torch.stack(x, 0),
priority_key=priority_key,
)
td1 = TensorDict(
@@ -537,7 +524,6 @@ def test_rb_trajectories(stack):
5,
alpha=0.7,
beta=0.9,
collate_fn=lambda x: torch.stack(x, 0),
priority_key="td_error",
)
rb.extend(traj_td)
@@ -565,10 +551,14 @@ def test_shared_storage_prioritized_sampler():
sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1)

rb0 = rb_prototype.ReplayBuffer(
storage=storage, writer=writer, sampler=sampler0, collate_fn=lambda x: x
storage=storage,
writer=writer,
sampler=sampler0,
)
rb1 = rb_prototype.ReplayBuffer(
storage=storage, writer=writer, sampler=sampler1, collate_fn=lambda x: x
storage=storage,
writer=writer,
sampler=sampler1,
)

data = TensorDict({"a": torch.arange(50)}, [50])
@@ -593,9 +583,11 @@ def test_legacy_rb_does_not_attach():
storage = LazyMemmapStorage(n)
writer = RoundRobinWriter()
sampler = RandomSampler()
rb = ReplayBuffer(storage=storage, size=n, prefetch=0, collate_fn=lambda x: x)
rb = ReplayBuffer(storage=storage, size=n, prefetch=0)
prb = rb_prototype.ReplayBuffer(
storage=storage, writer=writer, sampler=sampler, collate_fn=lambda x: x
storage=storage,
writer=writer,
sampler=sampler,
)

assert len(storage._attached_entities) == 1
10 changes: 6 additions & 4 deletions test/test_trainer.py
@@ -249,20 +249,22 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
S = 100
if storage_type == "list":
storage = ListStorage(S)
collate_fn = lambda x: torch.stack(x, 0)
elif storage_type == "memmap":
storage = LazyMemmapStorage(S)
collate_fn = lambda x: x
else:
raise NotImplementedError

if prioritized:
replay_buffer = TensorDictPrioritizedReplayBuffer(
S, 1.1, 0.9, storage=storage, collate_fn=collate_fn
S,
1.1,
0.9,
storage=storage,
)
else:
replay_buffer = TensorDictReplayBuffer(
S, storage=storage, collate_fn=collate_fn
S,
storage=storage,
)

N = 9
16 changes: 7 additions & 9 deletions torchrl/data/replay_buffers/rb_prototype.py
@@ -6,9 +6,9 @@
import torch
from tensordict.tensordict import TensorDictBase, LazyStackedTensorDict

from .replay_buffers import pin_memory_output, stack_tensors, stack_td
from .replay_buffers import pin_memory_output
from .samplers import Sampler, RandomSampler
from .storages import Storage, ListStorage
from .storages import Storage, ListStorage, _get_default_collate
from .utils import INT_CLASSES, _to_numpy, accept_remote_rref_udf_invocation
from .writers import Writer, RoundRobinWriter

@@ -47,7 +47,11 @@ def __init__(
self._writer = writer if writer is not None else RoundRobinWriter()
self._writer.register_storage(self._storage)

self._collate_fn = collate_fn or stack_tensors
self._collate_fn = (
collate_fn
if collate_fn is not None
else _get_default_collate(self._storage)
)
self._pin_memory = pin_memory

self._prefetch = bool(prefetch)
@@ -169,12 +173,6 @@ class TensorDictReplayBuffer(ReplayBuffer):
"""

def __init__(self, priority_key: str = "td_error", **kw) -> None:
if not kw.get("collate_fn"):

def collate_fn(x):
return stack_td(x, 0, contiguous=True)

kw["collate_fn"] = collate_fn
super().__init__(**kw)
self.priority_key = priority_key

25 changes: 10 additions & 15 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -13,7 +13,6 @@
import torch
from tensordict.tensordict import (
TensorDictBase,
_stack as stack_td,
LazyStackedTensorDict,
)
from torch import Tensor
@@ -24,7 +23,11 @@
SumSegmentTreeFp32,
SumSegmentTreeFp64,
)
from torchrl.data.replay_buffers.storages import Storage, ListStorage
from torchrl.data.replay_buffers.storages import (
Storage,
ListStorage,
_get_default_collate,
)
from torchrl.data.replay_buffers.utils import INT_CLASSES
from torchrl.data.replay_buffers.utils import (
_to_numpy,
@@ -118,9 +121,11 @@ def __init__(
self._storage = storage
self._capacity = size
self._cursor = 0
if collate_fn is None:
collate_fn = stack_tensors
self._collate_fn = collate_fn
self._collate_fn = (
collate_fn
if collate_fn is not None
else _get_default_collate(self._storage)
)
self._pin_memory = pin_memory

self._prefetch = prefetch is not None and prefetch > 0
@@ -558,11 +563,6 @@ def __init__(
prefetch: Optional[int] = None,
storage: Optional[Storage] = None,
):
if collate_fn is None:

def collate_fn(x):
return stack_td(x, 0, contiguous=True)

super().__init__(size, collate_fn, pin_memory, prefetch, storage=storage)


@@ -606,11 +606,6 @@ def __init__(
prefetch: Optional[int] = None,
storage: Optional[Storage] = None,
) -> None:
if collate_fn is None:

def collate_fn(x):
return stack_td(x, 0, contiguous=True)

super(TensorDictPrioritizedReplayBuffer, self).__init__(
size=size,
alpha=alpha,
27 changes: 27 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -414,3 +414,30 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
)
elif _CKPT_BACKEND == "torch":
return mem_map_tensor._tensor


def _collate_list_tensordict(x):
out = torch.stack(x, 0)
if isinstance(out, TensorDictBase):
return out.to_tensordict()
return out


def _collate_list_tensors(*x):
return tuple(torch.stack(_x, 0) for _x in zip(*x))


def _collate_contiguous(x):
if isinstance(x, TensorDictBase):
return x.to_tensordict()
return x.clone()


def _get_default_collate(storage, _is_tensordict=True):
if isinstance(storage, ListStorage):
if _is_tensordict:
return _collate_list_tensordict
else:
return _collate_list_tensors
elif isinstance(storage, (LazyTensorStorage, LazyMemmapStorage)):
return _collate_contiguous
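
For illustration, a rough sketch of how this dispatch behaves (not part of the commit; _get_default_collate is an internal helper, and the sample tensordicts below are hypothetical):

    import torch
    from tensordict import TensorDict

    # List storage: samples arrive as a list and are stacked along dim 0.
    collate = _get_default_collate(ListStorage(10))        # -> _collate_list_tensordict
    samples = [TensorDict({"obs": torch.randn(3)}, []) for _ in range(4)]
    batch = collate(samples)                               # TensorDict with batch_size [4]

    # Contiguous storage: the sampled batch is already stacked and is returned as a copy.
    collate = _get_default_collate(LazyMemmapStorage(10))  # -> _collate_contiguous
    batch = collate(TensorDict({"obs": torch.randn(4, 3)}, [4]))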
