
[Feature] Dynamic specs #2143

Merged (27 commits) on May 31, 2024
Changes from 1 commit
amend
vmoens committed May 26, 2024
commit f30062d7de8fbb0af05aea408835b95b8932ca8f
54 changes: 54 additions & 0 deletions test/mocking_classes.py
@@ -1896,3 +1896,57 @@ def _reset(self, tensordict=None):
        reset_td.update(self.full_done_spec.zero())
        assert reset_td.batch_size == self.batch_size
        return reset_td


class EnvWithDynamicSpec(EnvBase):
    def __init__(self, max_count=5):
        super().__init__(batch_size=())
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
        )
        self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
        self.full_done_spec = CompositeSpec(
            done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
            terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
            truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
        )
        self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
        self.count = 0
        self.max_count = max_count

    def _reset(self, tensordict=None):
        self.count = 0
        data = TensorDict(
            {
                "observation": torch.full(
                    (3, self.count + 1, 2),
                    self.count,
                    dtype=self.observation_spec["observation"].dtype,
                )
            }
        )
        data.update(self.done_spec.zero())
        return data

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        self.count += 1
        done = self.count >= self.max_count
        observation = TensorDict(
            {
                "observation": torch.full(
                    (3, self.count + 1, 2),
                    self.count,
                    dtype=self.observation_spec["observation"].dtype,
                )
            }
        )
        done = self.full_done_spec.zero() | done
        reward = self.full_reward_spec.zero()
        return observation.update(done).update(reward)

    def _set_seed(self, seed: Optional[int]):
        self.manual_seed = seed
        return seed
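
For context, not part of the diff: a minimal sketch of why this mock env needs a lazy stack. The observation grows along its second dimension at every step, so a dense torch.stack over the collected steps cannot succeed, whereas a lazy stack keeps each step's own shape (the tensordicts below are placeholders).

import torch
from tensordict import TensorDict, LazyStackedTensorDict

# Two consecutive "steps": the observation grows along dim 1 (count + 1).
step0 = TensorDict({"observation": torch.zeros(3, 1, 2)}, batch_size=[])
step1 = TensorDict({"observation": torch.ones(3, 2, 2)}, batch_size=[])

try:
    # Dense stacking requires identical leaf shapes and fails here.
    torch.stack([step0, step1], dim=0)
except RuntimeError as err:
    print(f"dense stack failed: {err}")

# A lazy stack keeps the heterogeneous shapes, which is what
# rollout(..., return_contiguous=False) relies on.
lazy = LazyStackedTensorDict.lazy_stack([step0, step1], dim=0)
print(lazy[0]["observation"].shape)  # torch.Size([3, 1, 2])
print(lazy[1]["observation"].shape)  # torch.Size([3, 2, 2])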
13 changes: 13 additions & 0 deletions test/test_env.py
@@ -42,6 +42,7 @@
    DiscreteActionConvMockEnvNumpy,
    DiscreteActionVecMockEnv,
    DummyModelBasedEnvBase,
    EnvWithDynamicSpec,
    HeterogeneousCountingEnv,
    HeterogeneousCountingEnvPolicy,
    MockBatchedLockedEnv,
@@ -3049,6 +3050,18 @@ def policy(td):
        assert (lazy_root["lidar"][1:][done[:-1].squeeze()] == 0).all()


class TestEnvWithDynamicSpec:
    def test_dynamic_rollout(self):
        env = EnvWithDynamicSpec()
        with pytest.raises(
            RuntimeError,
            match="The environment specs are dynamic. Call rollout with return_contiguous=False",
        ):
            rollout = env.rollout(4)
        rollout = env.rollout(4, return_contiguous=False)
        check_env_specs(env, return_contiguous=False)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
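
For illustration only, assuming test/mocking_classes.py is importable from the working directory: a non-contiguous rollout of the new mock env comes back as a lazily stacked tensordict whose per-step observations keep their own, growing shapes.

from mocking_classes import EnvWithDynamicSpec  # test helper added in this PR

env = EnvWithDynamicSpec(max_count=5)
rollout = env.rollout(4, return_contiguous=False)

# The stack dim is the time dim; each step keeps its own observation shape:
# (3, 1, 2), (3, 2, 2), (3, 3, 2), (3, 4, 2).
for i in range(rollout.shape[-1]):
    print(rollout[i]["observation"].shape)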
23 changes: 13 additions & 10 deletions test/test_specs.py
@@ -3467,19 +3467,22 @@ def test_all(self):
            disc=disc,
            moneh=moneh,
            mdisc=mdisc,
            shape=(-1, 1, 2)
            shape=(-1, 1, 2),
        )
        assert spec.shape == (-1, 1, 2)

        data = TensorDict({
            "unb": xunb,
            "unbd": xunbd,
            "bound": xbound,
            "oneh": xoneh,
            "disc": xdisc,
            "moneh": xmoneh,
            "mdisc": xmdisc,
        }, [3, 1, 2])
        data = TensorDict(
            {
                "unb": xunb,
                "unbd": xunbd,
                "bound": xbound,
                "oneh": xoneh,
                "disc": xdisc,
                "moneh": xmoneh,
                "mdisc": xmdisc,
            },
            [3, 1, 2],
        )
        assert spec.is_in(data)


25 changes: 16 additions & 9 deletions torchrl/data/tensor_specs.py
@@ -620,6 +620,11 @@ def ndim(self):
    def ndimension(self):
        return len(self.shape)

    @property
    def _safe_shape(self):
        """Returns a shape where all heterogeneous values are replaced by one (to be expandable)."""
        return torch.Size([v if v > 0 else 1 for v in self.shape])

    @abc.abstractmethod
    def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
        """Indexes the input tensor.
@@ -791,7 +796,9 @@ def zero(self, shape=None) -> torch.Tensor:
"""
if shape is None:
shape = torch.Size([])
return torch.zeros((*shape, *self.shape), dtype=self.dtype, device=self.device)
return torch.zeros(
(*shape, *self._safe_shape), dtype=self.dtype, device=self.device
)

@abc.abstractmethod
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec":
@@ -1653,14 +1660,14 @@ def __init__(
        else:
            shape_corr = None
        if high.ndimension():
            if shape_corr is not None and shape_corr != high.shape_corr:
            if shape_corr is not None and shape_corr != high.shape:
                raise RuntimeError(err_msg)
            shape = high.shape_corr
            shape = high.shape
            low = low.expand(shape_corr).clone()
        elif low.ndimension():
            if shape_corr is not None and shape_corr != low.shape_corr:
            if shape_corr is not None and shape_corr != low.shape:
                raise RuntimeError(err_msg)
            shape = low.shape_corr
            shape = low.shape
            high = high.expand(shape_corr).clone()
        elif shape_corr is None:
            raise RuntimeError(err_msg)
@@ -1673,7 +1680,7 @@
        elif high.numel() > low.numel():
            low = low.expand_as(high).clone()
        if shape_corr is None:
            shape = low.shape_corr
            shape = low.shape
        else:
            if isinstance(shape_corr, float):
                shape_corr = torch.Size([shape_corr])
@@ -1973,21 +1980,21 @@ def rand(self, shape=None):
        if shape is None:
            shape = ()
        return NonTensorData(
            data=None, batch_size=(*shape, *self.shape), device=self.device
            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
        )

    def zero(self, shape=None):
        if shape is None:
            shape = ()
        return NonTensorData(
            data=None, batch_size=(*shape, *self.shape), device=self.device
            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
        )

    def one(self, shape=None):
        if shape is None:
            shape = ()
        return NonTensorData(
            data=None, batch_size=(*shape, *self.shape), device=self.device
            data=None, batch_size=(*shape, *self._safe_shape), device=self.device
        )

    def is_in(self, val: torch.Tensor) -> bool:
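
For context, a standalone sketch (not library code) of what _safe_shape is meant to do: every -1 (heterogeneous) dimension is replaced by 1, so zero(), one() and rand() can allocate a concrete placeholder that remains expandable to any realized size.

import torch

def safe_shape(shape: torch.Size) -> torch.Size:
    """Mirror of TensorSpec._safe_shape: every -1 dim becomes 1."""
    return torch.Size([v if v > 0 else 1 for v in shape])

spec_shape = torch.Size([3, -1, 2])          # dynamic along dim 1
placeholder = torch.zeros(safe_shape(spec_shape))
print(placeholder.shape)                     # torch.Size([3, 1, 2])

# The size-1 dim can be expanded to whatever size the env actually produces.
print(placeholder.expand(3, 5, 2).shape)     # torch.Size([3, 5, 2])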
30 changes: 25 additions & 5 deletions torchrl/envs/common.py
@@ -2291,6 +2291,16 @@ def specs(self) -> CompositeSpec:
            shape=self.batch_size,
        ).lock_()

    @property
    def _has_dynamic_specs(self) -> bool:
        return any(
            any(s == -1 for s in spec.shape)
            for spec in self.output_spec.values(True, True)
        ) or any(
            any(s == -1 for s in spec.shape)
            for spec in self.input_spec.values(True, True)
        )

    def rollout(
        self,
        max_steps: int,
@@ -2564,9 +2574,19 @@ def rollout(
            tensordicts = self._rollout_nonstop(**kwargs)
        batch_size = self.batch_size if tensordict is None else tensordict.batch_size
        if return_contiguous:
            out_td = torch.stack(tensordicts, len(batch_size), out=out)
            try:
                out_td = torch.stack(tensordicts, len(batch_size), out=out)
            except RuntimeError as err:
                if (
                    "The shapes of the tensors to stack is incompatible" in str(err)
                    and self._has_dynamic_specs
                ):
                    raise RuntimeError(
                        "The environment specs are dynamic. Call rollout with return_contiguous=False."
                    )
                raise
        else:
            out_td = LazyStackedTensorDict.lazy_stack(
            out_td = LazyStackedTensorDict.maybe_dense_stack(
                tensordicts, len(batch_size), out=out
            )
        if set_truncated:
@@ -2927,9 +2947,12 @@ def fake_tensordict(self) -> TensorDictBase:
        full_done_spec = self.output_spec["full_done_spec"]

        fake_obs = observation_spec.zero()
        fake_reward = reward_spec.zero()
        fake_done = full_done_spec.zero()

        fake_state = state_spec.zero()
        fake_action = action_spec.zero()

        if any(
            isinstance(val, LazyStackedTensorDict) for val in fake_action.values(True)
        ):
@@ -2941,9 +2964,6 @@
        # Hence we generate the input, and override using the output
        fake_in_out = fake_input.update(fake_obs)

        fake_reward = reward_spec.zero()
        fake_done = full_done_spec.zero()

        next_output = fake_obs.clone()
        next_output.update(fake_reward)
        next_output.update(fake_done)
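
A rough standalone sketch of the detection rule added by _has_dynamic_specs (the helper below is hypothetical and works on plain shapes): an environment counts as dynamic as soon as any leaf spec, on the input or output side, carries a -1 in its shape.

import torch

def has_dynamic_shape(leaf_shapes) -> bool:
    """Hypothetical stand-in for _has_dynamic_specs, working on plain shapes."""
    return any(any(s == -1 for s in shape) for shape in leaf_shapes)

# Leaf shapes as they would be gathered from output_spec/input_spec.values(True, True):
print(has_dynamic_shape([torch.Size([3, 4, 2]), torch.Size([1])]))   # False
print(has_dynamic_shape([torch.Size([3, -1, 2]), torch.Size([1])]))  # True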
36 changes: 24 additions & 12 deletions torchrl/envs/utils.py
@@ -748,20 +748,32 @@ def check_env_specs(
    - List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
"""
        )
    if (
        fake_tensordict_select.apply(lambda x: torch.zeros_like(x))
        != real_tensordict_select.apply(lambda x: torch.zeros_like(x))
    ).any():
        raise AssertionError(
            "zeroing the two tensordicts did not make them identical. "
            f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
        )

    # Checks shapes and eventually dtypes of keys at all nesting levels
    _per_level_env_check(
        fake_tensordict_select, real_tensordict_select, check_dtype=check_dtype
    zeroing_err_msg = (
        "zeroing the two tensordicts did not make them identical. "
        f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
    )

    if env._has_dynamic_specs:
        for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)):
            fake = fake.apply(lambda x, y: x.expand_as(y), real)
            if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
                raise AssertionError(zeroing_err_msg)

            # Checks shapes and eventually dtypes of keys at all nesting levels
            _per_level_env_check(fake, real, check_dtype=check_dtype)

    else:
        if (
            torch.zeros_like(fake_tensordict_select)
            != torch.zeros_like(real_tensordict_select)
        ).any():
            raise AssertionError(zeroing_err_msg)

        # Checks shapes and eventually dtypes of keys at all nesting levels
        _per_level_env_check(
            fake_tensordict_select, real_tensordict_select, check_dtype=check_dtype
        )

    # Check specs
    last_td = real_tensordict[..., -1]
    last_td = env.rand_action(last_td)
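
A simplified, hypothetical sketch of the per-step comparison check_env_specs now performs when specs are dynamic: the fake and real rollouts are unbound along the last (time) dimension, the placeholder-sized fake leaves are expanded to the realized shapes, and only then are shapes and dtypes compared. The helper and data below are illustrative only.

import torch
from tensordict import TensorDict

def per_step_check(real_td: TensorDict, fake_td: TensorDict) -> None:
    """Hypothetical helper mirroring the dynamic-spec branch above."""
    for real, fake in zip(real_td.unbind(-1), fake_td.unbind(-1)):
        # Expand size-1 (placeholder) dims of the fake step to the real step's shapes.
        fake = fake.apply(lambda x, y: x.expand_as(y), real)
        for key in real.keys(True, True):
            assert real[key].shape == fake[key].shape, key
            assert real[key].dtype == fake[key].dtype, key

# A 2-step "rollout": the real observation realizes the dynamic dim as 4,
# the fake one carries the size-1 placeholder produced by zero().
real = TensorDict({"observation": torch.randn(2, 3, 4, 2)}, batch_size=[2])
fake = TensorDict({"observation": torch.zeros(2, 3, 1, 2)}, batch_size=[2])
per_step_check(real, fake)  # passes once the placeholder dim is expanded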