Skip to content

Commit

Permalink
[BugFix] More robust _StepMDP and multi-purpose envs (pytorch#2038)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 25, 2024
1 parent e835770 commit 1fcd3e3
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 17 deletions.
43 changes: 32 additions & 11 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ def test_seed():
torch.testing.assert_close(rollout1["observation"], rollout2["observation"])


@pytest.mark.filterwarnings("error")
class TestStepMdp:
@pytest.mark.parametrize("keep_other", [True, False])
@pytest.mark.parametrize("exclude_reward", [True, False])
Expand Down Expand Up @@ -1339,7 +1340,7 @@ def test_step_class(
env = envcls()

tensordict = env.rand_step(env.reset())
out = step_mdp(
out_func = step_mdp(
tensordict.lock_(),
keep_other=keep_other,
exclude_reward=exclude_reward,
Expand All @@ -1356,8 +1357,8 @@ def test_step_class(
exclude_done=exclude_done,
exclude_action=exclude_action,
)
out2 = step_func(tensordict)
assert (out == out2).all()
out_cls = step_func(tensordict)
assert (out_func == out_cls).all()

@pytest.mark.parametrize("nested_obs", [True, False])
@pytest.mark.parametrize("nested_action", [True, False])
Expand Down Expand Up @@ -1718,6 +1719,25 @@ def test_heterogeenous(
assert td[..., i][nested_other_key].shape == (td_batch_size, 1)
assert (td[..., i][nested_other_key] == 0).all()

@pytest.mark.parametrize("serial", [False, True])
def test_multi_purpose_env(self, serial):
# Tests that even if it's validated, the same env can be used within a collector
# and independently of it.
if serial:
env = SerialEnv(2, ContinuousActionVecMockEnv)
else:
env = ContinuousActionVecMockEnv()
rollout = env.rollout(10)
assert env._step_mdp.validate(None)
c = SyncDataCollector(
env, env.rand_action, frames_per_batch=10, total_frames=20
)
for data in c: # noqa: B007
pass
assert ("collector", "traj_ids") in data.keys(True)
assert env._step_mdp.validate(None)
rollout = env.rollout(10)


@pytest.mark.parametrize("device", get_default_devices())
def test_batch_locked(device):
Expand Down Expand Up @@ -2644,7 +2664,8 @@ def _step(self, tensordict):
"reward": action.sum().unsqueeze(0),
**self.full_done_spec.zero(),
"observation": obs,
}
},
batch_size=[],
)

torch.manual_seed(0)
Expand Down Expand Up @@ -2809,19 +2830,19 @@ def test_single_task_share_individual_td():

def test_stackable():
    """Sanity checks for the ``_stackable`` utility.

    ``_stackable(*tds)`` reports whether the given TensorDicts can be densely
    stacked; the expectations below pin its behavior for matching, mismatched
    and nested key sets.

    Note: the scraped diff interleaved each old assignment (without an
    explicit batch size) with its replacement; only the corrected
    post-change statements (explicit empty batch size ``[]``) are kept here.
    """
    # Disjoint key sets cannot be densely stacked.
    stack = [TensorDict({"a": 0}, []), TensorDict({"b": 1}, [])]
    assert not _stackable(*stack), torch.stack(stack)
    # Same key, incompatible shapes ([1] vs scalar): not stackable.
    stack = [TensorDict({"a": [0]}, []), TensorDict({"a": 1}, [])]
    assert not _stackable(*stack)
    # Matching keys and shapes: stackable.
    stack = [TensorDict({"a": [0]}, []), TensorDict({"a": [1]}, [])]
    assert _stackable(*stack)
    # An extra empty sub-tensordict does not prevent stacking.
    stack = [TensorDict({"a": [0]}, []), TensorDict({"a": [1], "b": {}}, [])]
    assert _stackable(*stack)
    # Nested entries with matching shapes: stackable.
    stack = [TensorDict({"a": {"b": [0]}}, []), TensorDict({"a": {"b": [1]}}, [])]
    assert _stackable(*stack)
    # Nested entries with incompatible shapes: not stackable.
    stack = [TensorDict({"a": {"b": [0]}}, []), TensorDict({"a": {"b": 1}}, [])]
    assert not _stackable(*stack)
    # Non-tensor (string) payloads are considered stackable.
    stack = [TensorDict({"a": "a string"}, []), TensorDict({"a": "another string"}, [])]
    assert _stackable(*stack)


Expand Down
1 change: 1 addition & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10266,6 +10266,7 @@ def test_multistep_transform(self):
rollout = env.rollout(
2, auto_reset=False, tensordict=td, break_when_any_done=False
).contiguous()
assert "reward" not in rollout.keys()
out = t._inv_call(rollout)
td = rollout[..., -1]["next"]
if out is not None:
Expand Down
60 changes: 54 additions & 6 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self.keep_other = keep_other
self.exclude_action = exclude_action

self.exclude_from_root = ["next", *self.done_keys]
self.keys_from_next = list(self.observation_keys)
if not exclude_reward:
self.keys_from_next += self.reward_keys
Expand All @@ -134,8 +135,18 @@ def __init__(
self.keys_from_root = []
if not exclude_action:
self.keys_from_root += self.action_keys
else:
self.exclude_from_root += self.action_keys
if keep_other:
self.keys_from_root += self.state_keys
else:
self.exclude_from_root += self.state_keys

reset_keys = {_replace_last(key, "_reset") for key in self.done_keys}
self.exclude_from_root += list(reset_keys)
self.exclude_from_root += list(self.reward_keys)

self.exclude_from_root = self._repr_key_list_as_tree(self.exclude_from_root)
self.keys_from_root = self._repr_key_list_as_tree(self.keys_from_root)
self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next)
self.validated = None
Expand All @@ -155,23 +166,26 @@ def validate(self, tensordict):
+ [unravel_key(("next", key)) for key in self.reward_keys]
)
actual = set(tensordict.keys(True, True))
self.validated = set(expected) == actual
expected = set(expected)
self.validated = expected.union(actual) == expected
if not self.validated:
warnings.warn(
"The expected key set and actual key set differ. "
"This will work but with a slower throughput than "
"when the specs match exactly the actual key set "
"in the data. "
f"Expected - Actual keys={set(expected) - actual}, \n"
f"Actual - Expected keys={actual- set(expected)}."
f"{{Expected keys}}-{{Actual keys}}={set(expected) - actual}, \n"
f"{{Actual keys}}-{{Expected keys}}={actual- set(expected)}."
)
return self.validated

@staticmethod
def _repr_key_list_as_tree(key_list):
"""Represents the keys as a tree to facilitate iteration."""
if not key_list:
return {}
key_dict = {key: torch.zeros(()) for key in key_list}
td = TensorDict(key_dict)
td = TensorDict(key_dict, batch_size=torch.Size([]))
return tree_map(lambda x: None, td.to_dict())

@classmethod
Expand Down Expand Up @@ -201,6 +215,36 @@ def _grab_and_place(
data_out._set_str(key, val, validated=True, inplace=False)
return data_out

    @classmethod
    def _exclude(
        cls, nested_key_dict: dict, data_in: TensorDictBase, out: TensorDictBase | None
    ) -> TensorDictBase | None:
        """Copies the entries if they're not part of the list of keys to exclude.

        Args:
            nested_key_dict: tree of keys to exclude — nested dicts with
                ``None`` leaves (presumably produced by
                ``_repr_key_list_as_tree``; confirm against callers).
            data_in: the tensordict to filter.
            out: optional destination; when ``None`` it is created lazily from
                ``data_in.empty()`` only once there is something to copy.

        Returns:
            the destination tensordict, or ``None`` when every entry of
            ``data_in`` was excluded (nothing was written).
        """
        if isinstance(data_in, LazyStackedTensorDict):
            # Recurse pairwise over the stacked sub-tensordicts so the lazy
            # stack structure is preserved in the output.
            if out is None:
                out = data_in.empty()
            for td, td_out in zip(data_in.tensordicts, out.tensordicts):
                cls._exclude(nested_key_dict, td, td_out)
            return out
        has_set = False
        for key, value in data_in.items():
            # NO_DEFAULT sentinel distinguishes "key absent from the exclusion
            # tree" from a present key mapped to None.
            subdict = nested_key_dict.get(key, NO_DEFAULT)
            if subdict is NO_DEFAULT:
                # Key is not excluded at all: copy it over wholesale.
                value = value.copy() if is_tensor_collection(value) else value
                if not has_set and out is None:
                    out = data_in.empty()
                out._set_str(key, value, validated=True, inplace=False)
                has_set = True
            elif subdict is not None:
                # Key is only partially excluded: filter its sub-tensordict
                # recursively and keep whatever survives.
                value = cls._exclude(subdict, value, None)
                if value is not None:
                    if not has_set and out is None:
                        out = data_in.empty()
                    out._set_str(key, value, validated=True, inplace=False)
                    has_set = True
            # A None leaf means the key is excluded entirely: skip it.
        if has_set:
            # Implicitly returns None when nothing was copied.
            return out

def __call__(self, tensordict):
if isinstance(tensordict, LazyStackedTensorDict):
out = LazyStackedTensorDict.lazy_stack(
Expand All @@ -210,12 +254,16 @@ def __call__(self, tensordict):
return out

next_td = tensordict._get_str("next", None)
out = next_td.empty()
if self.validate(tensordict):
self._grab_and_place(self.keys_from_root, tensordict, out)
if self.keep_other:
out = self._exclude(self.exclude_from_root, tensordict, out=None)
else:
out = next_td.empty()
self._grab_and_place(self.keys_from_root, tensordict, out)
self._grab_and_place(self.keys_from_next, next_td, out)
return out
else:
out = next_td.empty()
total_key = ()
if self.keep_other:
for key in tensordict.keys():
Expand Down

0 comments on commit 1fcd3e3

Please sign in to comment.