diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index b3ccb9caf51..504634917bc 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -421,12 +421,12 @@ def __new__(
             shape=batch_size,
         )
         if action_spec is None:
-            action_spec_cls = (
-                DiscreteTensorSpec
-                if categorical_action_encoding
-                else OneHotDiscreteTensorSpec
-            )
-            action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
+            if categorical_action_encoding:
+                action_spec_cls = DiscreteTensorSpec
+                action_spec = action_spec_cls(n=7, shape=batch_size)
+            else:
+                action_spec_cls = OneHotDiscreteTensorSpec
+                action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
         if reward_spec is None:
             reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
         if done_spec is None:
@@ -1053,7 +1053,7 @@ def _step(
             batch_size=self.batch_size,
             device=self.device,
         )
-        return tensordict.select().set("next", tensordict)
+        return tensordict


 class NestedCountingEnv(CountingEnv):
diff --git a/test/test_env.py b/test/test_env.py
index a75b095db4c..6ad03208e3d 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -143,7 +143,7 @@ def test_env_seed(env_name, frame_skip, seed=0):
     env.set_seed(seed)
     td0b = env.fake_tensordict()
     td0b = env.reset(tensordict=td0b)
-    td1b = env.step(td0b.clone().set("action", action))
+    td1b = env.step(td0b.exclude("next").clone().set("action", action))

     assert_allclose_td(td0a, td0b.select(*td0a.keys()))
     assert_allclose_td(td1a, td1b)
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index a9134dd5ce7..ca1e0e46e57 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -7,8 +7,14 @@
 import pytest
 import torch

+from mocking_classes import DiscreteActionVecMockEnv
 from tensordict import pad, TensorDict, unravel_key_list
-from tensordict.nn import InteractionType, make_functional, TensorDictModule
+from tensordict.nn import (
+    InteractionType,
+    make_functional,
+    TensorDictModule,
+    TensorDictSequential,
+)
 from torch import nn
 from torchrl.data.tensor_specs import (
     BoundedTensorSpec,
@@ -21,6 +27,7 @@
     DecisionTransformerInferenceWrapper,
     DTActor,
     LSTMModule,
+    MLP,
     NormalParamWrapper,
     OnlineDTActor,
     ProbabilisticActor,
@@ -1765,6 +1772,55 @@ def test_multi_consecutive(self, shape):
             td_ss["intermediate"], td["intermediate"][..., -1, :]
         )

+    def test_lstm_parallel_env(self):
+        from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv
+
+        # tests that hidden states are carried over with parallel envs
+        lstm_module = LSTMModule(
+            input_size=7,
+            hidden_size=12,
+            num_layers=2,
+            in_key="observation",
+            out_key="features",
+        )
+
+        def create_transformed_env():
+            primer = lstm_module.make_tensordict_primer()
+            env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
+            env = TransformedEnv(env)
+            env.append_transform(InitTracker())
+            env.append_transform(primer)
+            return env
+
+        env = ParallelEnv(
+            create_env_fn=create_transformed_env,
+            num_workers=2,
+        )
+
+        mlp = TensorDictModule(
+            MLP(
+                in_features=12,
+                out_features=7,
+                num_cells=[],
+            ),
+            in_keys=["features"],
+            out_keys=["logits"],
+        )
+
+        actor_model = TensorDictSequential(lstm_module, mlp)
+
+        actor = ProbabilisticActor(
+            module=actor_model,
+            in_keys=["logits"],
+            out_keys=["action"],
+            distribution_class=torch.distributions.Categorical,
+            return_log_prob=True,
+        )
+        for break_when_any_done in [False, True]:
+            data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
+            assert (data.get("recurrent_state_c") != 0.0).any()
+            assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
+

 def test_safe_specs():
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 40037085e8d..ae4d533cd91 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -430,10 +430,9 @@ def test_transform_env_clone(self):
         value_at_clone = td["next", "observation"].clone()
         for _ in range(10):
             td = env.rand_step(td)
-        assert (td["next", "observation"] != value_at_clone).any()
-        assert (
-            td["next", "observation"] == env.transform._cat_buffers_observation
-        ).all()
+            td = step_mdp(td)
+        assert (td["observation"] != value_at_clone).any()
+        assert (td["observation"] == env.transform._cat_buffers_observation).all()
         assert (
             cloned._cat_buffers_observation == env.transform._cat_buffers_observation
         ).all()
@@ -6693,6 +6692,7 @@ def _test_vecnorm_subproc_auto(
     tensordict = env.reset()
     for _ in range(10):
         tensordict = env.rand_step(tensordict)
+        tensordict = step_mdp(tensordict)
     queue_out.put(True)
     msg = queue_in.get(timeout=TIMEOUT)
     assert msg == "all_done"
@@ -6800,11 +6800,13 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
     assert msg == "start"
     for _ in range(10):
         tensordict = parallel_env.rand_step(tensordict)
+        tensordict = step_mdp(tensordict)
     queue_out.put("first round")
     msg = queue_in.get(timeout=TIMEOUT)
     assert msg == "start"
     for _ in range(10):
         tensordict = parallel_env.rand_step(tensordict)
+        tensordict = step_mdp(tensordict)
     queue_out.put("second round")
     parallel_env.close()
     queue_out.close()
@@ -6884,6 +6886,7 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200):
         for _ in range(N):
             td = env_t.rand_step(td)
             tds.append(td.clone())
+            td = step_mdp(td)
             if td.get("done").any():
                 td = env_t.reset()
         tds = torch.stack(tds, 0)
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 469d6bd65bc..08ed915ce80 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -1118,6 +1118,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         Args:
             tensordict (TensorDictBase): Tensordict containing the action to be taken.
+                If the input tensordict contains a ``"next"`` entry, the values contained in it
+                will prevail over the newly computed values. This gives a mechanism
+                to override the underlying computations.

         Returns:
             the input tensordict, modified in place with the resulting observations, done state and reward

@@ -1126,10 +1129,13 @@
         """
         # sanity check
         self._assert_tensordict_shape(tensordict)
+        next_preset = tensordict.get("next", None)
         next_tensordict = self._step(tensordict)
         next_tensordict = self._step_proc_data(next_tensordict)
-        # tensordict could already have a "next" key
+        if next_preset is not None:
+            # tensordict could already have a "next" key
+            next_tensordict.update(next_preset)
         tensordict.set("next", next_tensordict)
         return tensordict

@@ -1669,11 +1675,14 @@ def fake_tensordict(self) -> TensorDictBase:
         next_output.update(fake_reward)
         next_output.update(fake_done)
         fake_in_out.update(fake_done.clone())
+        if "next" not in fake_in_out.keys():
+            fake_in_out.set("next", next_output)
+        else:
+            fake_in_out.get("next").update(next_output)

-        fake_td = fake_in_out.set("next", next_output)
-        fake_td.batch_size = self.batch_size
-        fake_td = fake_td.to(self.device)
-        return fake_td
+        fake_in_out.batch_size = self.batch_size
+        fake_in_out = fake_in_out.to(self.device)
+        return fake_in_out


 class _EnvWrapper(EnvBase, metaclass=abc.ABCMeta):
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index d511b069612..6baa4ad267d 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -5,11 +5,11 @@
 from typing import Optional, Tuple

 import torch
-from tensordict import unravel_key_list
+from tensordict import TensorDictBase, unravel_key_list
 from tensordict.nn import TensorDictModuleBase as ModuleBase
-from tensordict.tensordict import NO_DEFAULT, TensorDictBase
+from tensordict.tensordict import NO_DEFAULT
 from tensordict.utils import prod
 from torch import nn
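
Note (not part of the patch): every test change above follows the same new contract. `step`/`rand_step` now leave the current step's data at the root of the tensordict and write the transition under "next", so callers advance the MDP explicitly with `step_mdp`. Below is a minimal sketch of that pattern on plain tensordicts, assuming `step_mdp`'s default key handling (rewards and actions excluded from the output); the shapes and key names are illustrative only, not tied to any real env.

    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import step_mdp

    # A transition laid out the way step()/rand_step() now produce it: current
    # data at the root, the new observation/reward/done under "next".
    td = TensorDict(
        {
            "observation": torch.zeros(3),
            "action": torch.ones(2),
            "next": TensorDict(
                {
                    "observation": torch.ones(3),
                    "reward": torch.ones(1),
                    "done": torch.zeros(1, dtype=torch.bool),
                },
                [],
            ),
        },
        [],
    )

    td = step_mdp(td)  # promote td["next"] to the root for the following step
    assert (td["observation"] == 1).all()  # the next observation is now current
    assert "action" not in td.keys()  # actions are excluded by default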
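Similarly, the `next_preset` logic added to `EnvBase.step` means that values the caller pre-loads under "next" take precedence over what the env computes, because `next_tensordict.update(next_preset)` runs after `_step`. A standalone sketch of that precedence rule with hypothetical keys (no env involved):

    import torch
    from tensordict import TensorDict

    # What the env computed for this step...
    computed = TensorDict({"observation": torch.zeros(3), "reward": torch.zeros(1)}, [])
    # ...and what the caller preset under tensordict["next"] before calling step().
    preset = TensorDict({"observation": torch.full((3,), 42.0)}, [])

    computed.update(preset)  # mirrors next_tensordict.update(next_preset)
    assert (computed["observation"] == 42).all()  # preset values prevail
    assert (computed["reward"] == 0).all()  # keys absent from the preset are kept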