Skip to content

Commit

Permalink
[BugFix] Fix nested CompositeSpec creation (pytorch#1261)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 13, 2023
1 parent 2de3159 commit d7ca44c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 71 deletions.
187 changes: 117 additions & 70 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,26 +368,35 @@ def test_multi_discrete_conversion(ns, shape, device):
@pytest.mark.parametrize("is_complete", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@pytest.mark.parametrize("shape", [(), (2, 3)])
class TestComposite:
@staticmethod
def _composite_spec(is_complete=True, device=None, dtype=None):
def _composite_spec(shape, is_complete=True, device=None, dtype=None):
torch.manual_seed(0)
np.random.seed(0)

return CompositeSpec(
obs=BoundedTensorSpec(
torch.zeros(3, 32, 32),
torch.ones(3, 32, 32),
torch.zeros(*shape, 3, 32, 32),
torch.ones(*shape, 3, 32, 32),
dtype=dtype,
device=device,
),
act=UnboundedContinuousTensorSpec((7,), dtype=dtype, device=device)
act=UnboundedContinuousTensorSpec(
(
*shape,
7,
),
dtype=dtype,
device=device,
)
if is_complete
else None,
shape=shape,
)

def test_getitem(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_getitem(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
assert isinstance(ts["obs"], BoundedTensorSpec)
if is_complete:
assert isinstance(ts["act"], UnboundedContinuousTensorSpec)
Expand All @@ -396,35 +405,39 @@ def test_getitem(self, is_complete, device, dtype):
with pytest.raises(KeyError):
_ = ts["UNK"]

def test_setitem_forbidden_keys(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
for key in {"shape", "device", "dtype", "space"}:
with pytest.raises(AttributeError, match="cannot be set"):
ts[key] = 42

@pytest.mark.parametrize("dest", get_available_devices())
def test_setitem_matches_device(self, is_complete, device, dtype, dest):
ts = self._composite_spec(is_complete, device, dtype)
def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest):
ts = self._composite_spec(shape, is_complete, device, dtype)

if dest == device:
ts["good"] = UnboundedContinuousTensorSpec(device=dest, dtype=dtype)
ts["good"] = UnboundedContinuousTensorSpec(
shape=shape, device=dest, dtype=dtype
)
assert ts["good"].device == dest
else:
with pytest.raises(
RuntimeError, match="All devices of CompositeSpec must match"
):
ts["bad"] = UnboundedContinuousTensorSpec(device=dest, dtype=dtype)
ts["bad"] = UnboundedContinuousTensorSpec(
shape=shape, device=dest, dtype=dtype
)

def test_del(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_del(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
assert "obs" in ts.keys()
assert "act" in ts.keys()
del ts["obs"]
assert "obs" not in ts.keys()
assert "act" in ts.keys()

def test_encode(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_encode(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
if dtype is None:
dtype = torch.get_default_dtype()

Expand All @@ -441,24 +454,25 @@ def test_encode(self, is_complete, device, dtype):
assert encoded_vals["act"].dtype == dtype
assert (encoded_vals["act"] == r["act"]).all()

def test_is_in(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_is_in(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
for _ in range(100):
r = ts.rand()
assert ts.is_in(r)

def test_to_numpy(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_to_numpy(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
for _ in range(100):
r = ts.rand()
for key, value in ts.to_numpy(r).items():
spec = ts[key]
assert (spec.to_numpy(r[key]) == value).all()

@pytest.mark.parametrize("shape", [[], [3]])
def test_project(self, is_complete, device, dtype, shape):
ts = self._composite_spec(is_complete, device, dtype)
@pytest.mark.parametrize("shape_other", [[], [5]])
def test_project(self, shape, is_complete, device, dtype, shape_other):
ts = self._composite_spec(shape, is_complete, device, dtype)
# Using normal distribution to get out of bounds
shape = (*shape_other, *shape)
tensors = {"obs": torch.randn(*shape, 3, 32, 32, dtype=dtype, device=device)}
if is_complete:
tensors["act"] = torch.randn(*shape, 7, dtype=dtype, device=device)
Expand All @@ -469,36 +483,36 @@ def test_project(self, is_complete, device, dtype, shape):
assert ts.is_in(out_of_bounds_td)
assert out_of_bounds_td.shape == torch.Size(shape)

@pytest.mark.parametrize("shape", [[], [3]])
def test_rand(self, is_complete, device, dtype, shape):
ts = self._composite_spec(is_complete, device, dtype)
@pytest.mark.parametrize("shape_other", [[], [3]])
def test_rand(self, shape, is_complete, device, dtype, shape_other):
ts = self._composite_spec(shape, is_complete, device, dtype)
if dtype is None:
dtype = torch.get_default_dtype()

rand_td = ts.rand(shape)
shape = (*shape_other, *shape)
rand_td = ts.rand(shape_other)
assert rand_td.shape == torch.Size(shape)
assert rand_td.get("obs").shape == torch.Size([*shape, 3, 32, 32])
assert rand_td.get("obs").dtype == dtype
if is_complete:
assert rand_td.get("act").shape == torch.Size([*shape, 7])
assert rand_td.get("act").dtype == dtype

def test_repr(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_repr(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
output = repr(ts)
assert output.startswith("CompositeSpec")
assert "obs: " in output
assert "act: " in output

def test_device_cast_with_dtype_fails(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_device_cast_with_dtype_fails(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
with pytest.raises(ValueError, match="Only device casting is allowed"):
ts.to(torch.float16)

@pytest.mark.parametrize("dest", get_available_devices())
def test_device_cast(self, is_complete, device, dtype, dest):
def test_device_cast(self, shape, is_complete, device, dtype, dest):
# Note: trivial test in case there is only one device available.
ts = self._composite_spec(is_complete, device, dtype)
ts = self._composite_spec(shape, is_complete, device, dtype)
ts.rand()
td_to = ts.to(dest)
cast_r = td_to.rand()
Expand All @@ -508,17 +522,17 @@ def test_device_cast(self, is_complete, device, dtype, dest):
if is_complete:
assert cast_r["act"].device == dest

def test_type_check(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
def test_type_check(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
rand_td = ts.rand()
ts.type_check(rand_td)
ts.type_check(rand_td["obs"], "obs")
if is_complete:
ts.type_check(rand_td["act"], "act")

def test_nested_composite_spec(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
assert set(ts.keys()) == {
"obs",
"act",
Expand Down Expand Up @@ -548,49 +562,60 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
if key != "nested_cp":
assert key in td["nested_cp"].keys()

def test_nested_composite_spec_index(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_index(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(
shape, is_complete, device, dtype
)
assert ts["nested_cp"]["nested_cp"] is ts["nested_cp", "nested_cp"]
assert (
ts["nested_cp"]["nested_cp"]["obs"] is ts["nested_cp", "nested_cp", "obs"]
)

def test_nested_composite_spec_rand(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_rand(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(
shape, is_complete, device, dtype
)
r = ts.rand()
assert (r["nested_cp", "nested_cp", "obs"] >= 0).all()

def test_nested_composite_spec_zero(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_zero(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(
shape, is_complete, device, dtype
)
r = ts.zero()
assert (r["nested_cp", "nested_cp", "obs"] == 0).all()

def test_nested_composite_spec_setitem(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_setitem(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(
shape, is_complete, device, dtype
)
ts["nested_cp", "nested_cp", "obs"] = None
assert (
ts["nested_cp"]["nested_cp"]["obs"] is ts["nested_cp", "nested_cp", "obs"]
)
assert ts["nested_cp"]["nested_cp"]["obs"] is None
ts["nested_cp", "another", "obs"] = None

def test_nested_composite_spec_delitem(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_delitem(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"]["nested_cp"] = self._composite_spec(
shape, is_complete, device, dtype
)
del ts["nested_cp", "nested_cp", "obs"]
assert ("nested_cp", "nested_cp", "obs") not in ts.keys(True, True)

def test_nested_composite_spec_update(self, is_complete, device, dtype):
ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
td2 = CompositeSpec(new=None)
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
Expand All @@ -602,8 +627,8 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
"new",
}

ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device))
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
Expand All @@ -615,8 +640,8 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
("nested_cp", "new"),
}

ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
Expand All @@ -628,12 +653,18 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
}
assert ts["nested_cp"]["act"] is None

ts = self._composite_spec(is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
td2 = CompositeSpec(
nested_cp=CompositeSpec(act=None, shape=shape).to(device), shape=shape
)
ts.update(td2)
td2 = CompositeSpec(
nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device))
nested_cp=CompositeSpec(
act=UnboundedContinuousTensorSpec(shape=shape, device=device),
shape=shape,
),
shape=shape,
)
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
Expand All @@ -646,6 +677,22 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
assert ts["nested_cp"]["act"] is not None


@pytest.mark.parametrize("shape", [(), (2, 3)])
@pytest.mark.parametrize("device", get_default_devices())
def test_create_composite_nested(shape, device):
d = [
{("a", "b"): UnboundedContinuousTensorSpec(shape=shape, device=device)},
{"a": {"b": UnboundedContinuousTensorSpec(shape=shape, device=device)}},
]
for _d in d:
c = CompositeSpec(_d, shape=shape)
assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec)
assert c["a"].shape == torch.Size(shape)
assert c.device == device
assert c["a"].device == device
assert c["a", "b"].device == device


@pytest.mark.parametrize("recurse", [True, False])
def test_lock(recurse):
shape = [3, 4, 5]
Expand Down
4 changes: 3 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2578,6 +2578,8 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
f"Expected a dictionary of specs, but got an argument of type {type(argdict)}."
)
for k, item in argdict.items():
if isinstance(item, dict):
item = CompositeSpec(item, shape=shape)
if item is not None:
if self._device is None:
self._device = item.device
Expand Down Expand Up @@ -2656,7 +2658,7 @@ def __getitem__(self, idx):
def __setitem__(self, key, value):
if isinstance(key, tuple) and len(key) > 1:
if key[0] not in self.keys(True):
self[key[0]] = CompositeSpec()
self[key[0]] = CompositeSpec(shape=self.shape)
self[key[0]][key[1:]] = value
return
elif isinstance(key, tuple):
Expand Down

0 comments on commit d7ca44c

Please sign in to comment.