From 99332f5ce4b4716908119089d2f10e3006de6923 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:44:54 +0200 Subject: [PATCH] [BugFix] TensorDictPrimer updates spec instead of overwriting (#2332) Co-authored-by: Vincent Moens --- test/test_transforms.py | 30 +++++++++++++++++++++++++++ torchrl/envs/transforms/transforms.py | 26 +++++++++++------------ 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 94ec8b2716c..c38908eba1d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6676,6 +6676,36 @@ def test_single_trans_env_check(self): assert "mykey" in env.reset().keys() assert ("next", "mykey") in env.rollout(3).keys(True) + def test_nested_key_env(self): + env = MultiKeyCountingEnv() + env_obs_spec_prior_primer = env.observation_spec.clone() + env = TransformedEnv( + env, + TensorDictPrimer( + CompositeSpec( + { + "nested_1": CompositeSpec( + { + "mykey": UnboundedContinuousTensorSpec( + (env.nested_dim_1, 4) + ) + }, + shape=(env.nested_dim_1,), + ) + } + ), + reset_key="_reset", + ), + ) + check_env_specs(env) + env_obs_spec_post_primer = env.observation_spec.clone() + assert ("nested_1", "mykey") in env_obs_spec_post_primer.keys(True, True) + del env_obs_spec_post_primer[("nested_1", "mykey")] + assert env_obs_spec_post_primer == env_obs_spec_prior_primer + + assert ("nested_1", "mykey") in env.reset().keys(True, True) + assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True) + def test_transform_no_env(self): t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])) td = TensorDict({"a": torch.zeros(())}, []) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1a66ee489a6..7c9dec980f5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4646,7 +4646,7 @@ def __init__( self.reset_key = reset_key 
# sanity check - for spec in self.primers.values(): + for spec in self.primers.values(True, True): if not isinstance(spec, TensorSpec): raise ValueError( "The values of the primers must be a subtype of the TensorSpec class. " @@ -4705,15 +4705,16 @@ def transform_observation_spec( raise ValueError( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) - for key, spec in self.primers.items(): - if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - expanded_spec = self._expand_shape(spec) - spec = expanded_spec + + if self.primers.shape != observation_spec.shape: try: - device = observation_spec.device - except RuntimeError: - device = self.device - observation_spec[key] = self.primers[key] = spec.to(device) + # We try to set the primer shape to the observation spec shape + self.primers.shape = observation_spec.shape + except ValueError: + # If we fail, we expand them to that shape + self.primers = self._expand_shape(self.primers) + device = observation_spec.device + observation_spec.update(self.primers.clone().to(device)) return observation_spec def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: @@ -4763,8 +4764,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for key in self.primers.keys(): - if key not in next_tensordict.keys(True): + for key in self.primers.keys(True, True): + if key not in next_tensordict.keys(True, True): prev_val = tensordict.get(key) next_tensordict.set(key, prev_val) return next_tensordict @@ -4782,9 +4783,6 @@ def _reset( _reset = _get_reset(self.reset_key, tensordict) if _reset.any(): for key, spec in self.primers.items(True, True): - if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size: - expanded_spec = self._expand_shape(spec) - self.primers[key] = spec = expanded_spec if self.random: shape = ( ()