[Refactor] Buffers tensorclass compat and tutorial (#1101)
Co-authored-by: Rohit Nigam <rohitnigam@meta.com>
Co-authored-by: Rohit Nigam <rohitnigam@gmail.com>
3 people authored May 5, 2023
1 parent aad6684 commit 39fe662
Showing 14 changed files with 878 additions and 86 deletions.
14 changes: 14 additions & 0 deletions test/_utils_internal.py
@@ -16,6 +16,8 @@
import pytest
import torch
import torch.cuda

from tensordict import tensorclass
from torchrl._utils import implement_for, seed_generator

from torchrl.envs import ObservationNorm
@@ -295,3 +297,15 @@ def t_out():
        )

    return t_out


def make_tc(td):
    """Makes a tensorclass from a tensordict instance."""

    class MyClass:
        pass

    MyClass.__annotations__ = {}
    for key in td.keys():
        MyClass.__annotations__[key] = torch.Tensor
    return tensorclass(MyClass)
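
A quick usage sketch (mine, not part of the diff) showing how make_tc is meant to be used, mirroring the construction the new tests perform:

import torch
from tensordict import TensorDict, is_tensorclass

td = TensorDict({"obs": torch.randn(3, 4), "act": torch.randn(3, 2)}, batch_size=[3])
MyTC = make_tc(td)  # tensorclass whose fields mirror td's keys
tc = MyTC(**td, batch_size=td.batch_size)  # same pattern as in test_rb.py below
assert is_tensorclass(tc)
assert (tc.obs == td["obs"]).all()  # fields hold the original tensors
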
59 changes: 46 additions & 13 deletions test/test_rb.py
@@ -12,7 +12,7 @@
import numpy as np
import pytest
import torch
from _utils_internal import get_available_devices
from _utils_internal import get_available_devices, make_tc
from tensordict import is_tensorclass, tensorclass
from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
from torchrl.data import (
@@ -129,9 +129,14 @@ def test_add(self, rb_type, sampler, writer, storage, size):
        )
        data = self._get_datum(rb_type)
        rb.add(data)
        s = rb._storage[0]
        s = rb.sample(1)
        assert s.ndim, s
        s = s[0]
        if isinstance(s, TensorDictBase):
            assert (s == data.select(*s.keys())).all()
            s = s.select(*data.keys(True), strict=False)
            data = data.select(*s.keys(True), strict=False)
            assert (s == data).all()
            assert list(s.keys(True, True))
        else:
            assert (s == data).all()
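
The strict=False selects above compare only the keys present on both sides; a tiny sketch (my example, using only the tensordict API already imported here):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(2)}, batch_size=[2])
sub = td.select("a", "missing", strict=False)  # unknown keys are skipped, not raised
assert list(sub.keys()) == ["a"]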

@@ -373,14 +378,22 @@ def test_prototype_prb(priority_key, contiguous, device):


@pytest.mark.parametrize("stack", [False, True])
@pytest.mark.parametrize("datatype", ["tc", "tb"])
@pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"])
def test_replay_buffer_trajectories(stack, reduction):
def test_replay_buffer_trajectories(stack, reduction, datatype):
    traj_td = TensorDict(
        {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
        batch_size=[3, 4],
    )
    if datatype == "tc":
        c = make_tc(traj_td)
        traj_td = c(**traj_td, batch_size=traj_td.batch_size)
        assert is_tensorclass(traj_td)
    elif datatype != "tb":
        raise NotImplementedError

    if stack:
        traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0)
        traj_td = torch.stack(list(traj_td), 0)

    rb = TensorDictReplayBuffer(
        sampler=samplers.PrioritizedSampler(
@@ -394,6 +407,10 @@ def test_replay_buffer_trajectories(stack, reduction):
    )
    rb.extend(traj_td)
    sampled_td = rb.sample()
    if datatype == "tc":
        assert is_tensorclass(traj_td)
        return

    sampled_td.set("td_error", torch.rand(sampled_td.shape))
    rb.update_tensordict_priority(sampled_td)
    sampled_td = rb.sample(include_info=True)
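
For context, iterating a TensorDict (or a tensorclass) yields views along the first batch dimension, which is why the torch.stack(list(traj_td), 0) rewrite above no longer needs to_tensordict(); a minimal sketch (mine):

import torch
from tensordict import TensorDict

td = TensorDict({"x": torch.arange(6).reshape(3, 2)}, batch_size=[3])
stacked = torch.stack(list(td), 0)  # re-stacks the per-index views
assert stacked.batch_size == td.batch_size
assert (stacked["x"] == td["x"]).all()
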
@@ -510,9 +527,12 @@ def test_add(self, rbtype, storage, size, prefetch):
        rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
        data = self._get_datum(rbtype)
        rb.add(data)
        s = rb._storage[0]
        s = rb.sample(1)[0]
        if isinstance(s, TensorDictBase):
            assert (s == data.select(*s.keys())).all()
            s = s.select(*data.keys(True), strict=False)
            data = data.select(*s.keys(True), strict=False)
            assert (s == data).all()
            assert list(s.keys(True, True))
        else:
            assert (s == data).all()

@@ -649,6 +669,7 @@ def test_prb(priority_key, contiguous, device):
        },
        batch_size=[3],
    ).to(device)

    rb.extend(td1)
    s = rb.sample()
    assert s.batch_size == torch.Size([5])
@@ -838,17 +859,29 @@ def test_insert_transform():

@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform(transform):
    rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1)
    rb = TensorDictReplayBuffer(
        transform=transform(in_keys=["observation"]), batch_size=1
    )

    # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, [])
    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, [])
    rb.add(td)
    rb.sample()

    rb._transform = mock.MagicMock()
    rb._transform.__len__ = lambda *args: 3
    m = mock.Mock()
    m.side_effect = [td.unsqueeze(0)]
    rb._transform.forward = m
    # rb._transform.__len__ = lambda *args: 3
    rb.sample()
    assert rb._transform.called
    assert rb._transform.forward.called

    # was_called = [False]
    # forward = rb._transform.forward
    # def new_forward(*args, **kwargs):
    #     was_called[0] = True
    #     return forward(*args, **kwargs)
    # rb._transform.forward = new_forward
    # rb.sample()
    # assert was_called[0]
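
A standalone sketch (my illustration, plain unittest.mock) of the stubbing pattern above: a one-element side_effect queue makes the mocked forward return the prepared batch exactly once, and the call is recorded for the assertion:

from unittest import mock

m = mock.Mock()
m.side_effect = ["stubbed result"]  # one-shot queue of return values
assert m() == "stubbed result"
assert m.called  # Mock records the call, as the test asserts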


transforms = [
11 changes: 9 additions & 2 deletions test/test_rb_distributed.py
@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
import time
@@ -53,7 +53,9 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl
    _, inserted = _add_random_tensor_dict_to_buffer(buffer)
    sampled = _sample_from_buffer(buffer, 1)
    assert type(sampled) is type(inserted) is TensorDict
    assert (sampled["a"] == inserted["a"]).all()
    a_sample = sampled["a"]
    a_insert = inserted["a"]
    assert (a_sample == a_insert).all()


@pytest.mark.skipif(
@@ -131,3 +133,8 @@ def _sample_from_buffer(buffer, batch_size):
    return rpc.rpc_sync(
        buffer.owner(), ReplayBufferNode.sample, args=(buffer, batch_size)
    )


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
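
A usage note (my gloss on the added __main__ block): the ArgumentParser defines no flags of its own, so parse_known_args() forwards any unrecognized arguments straight to pytest, e.g.:

python test/test_rb_distributed.py -k sample_from_buffer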
4 changes: 2 additions & 2 deletions test/test_trainer.py
@@ -289,7 +289,7 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
        trainer._process_batch_hook(td)
        td_out = trainer._process_optim_batch_hook(td)
        if prioritized:
            td_out.set(replay_buffer.priority_key, torch.rand(N))
            td_out.unlock_().set(replay_buffer.priority_key, torch.rand(N))
        trainer._post_loss_hook(td_out)

        trainer2 = mocking_trainer()
@@ -424,7 +424,7 @@ def make_storage():
        # sample from rb
        td_out = trainer._process_optim_batch_hook(td)
        if prioritized:
            td_out.set(replay_buffer.priority_key, torch.rand(N))
            td_out.unlock_().set(replay_buffer.priority_key, torch.rand(N))
        trainer._post_loss_hook(td_out)
        trainer.save_trainer(True)
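
A minimal sketch (mine, assuming tensordict's lock_/unlock_ pair) of why the tests now chain unlock_() before set(): a locked TensorDict rejects new keys, and unlock_() returns self, so the write can be chained:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(1)}, batch_size=[])
td.lock_()
try:
    td.set("b", torch.ones(1))  # writing a new key to a locked tensordict raises
except RuntimeError:
    pass
td.unlock_().set("b", torch.ones(1))  # unlock_() returns self, enabling the chain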

15 changes: 8 additions & 7 deletions torchrl/_utils.py
@@ -454,13 +454,14 @@ def context_decorator(ctx, func):
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.
    """
    assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "
        "context_decorator(lambda: ctx()); if you intended to pass a context "
        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
    )
    if callable(ctx) and hasattr(ctx, "__enter__"):
        raise RuntimeError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use. If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    if not callable(ctx):
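
A usage sketch (my illustration, assuming context_decorator returns the wrapped function, as in the PyTorch helper this mirrors): passing a zero-argument factory keeps ctx unambiguous and never trips the new RuntimeError:

import contextlib

@contextlib.contextmanager
def managed():
    yield  # stand-in for real setup/teardown

def fn():
    return 42

wrapped = context_decorator(managed, fn)  # managed is callable but has no __enter__
assert wrapped() == 42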

