[BugFix] TensorDictPrimer updates spec instead of overwriting (pytorch#2332)

Co-authored-by: Vincent Moens <vmoens@meta.com>
matteobettini and vmoens authored Jul 31, 2024
1 parent c1093b7 commit 99332f5
Showing 2 changed files with 42 additions and 14 deletions.
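
In short: when a primer's spec targets a nested key, writing it into the env's observation spec by direct assignment replaces the whole nested CompositeSpec, silently dropping the env's own entries; the fix merges via update instead. A minimal sketch of the difference, not from the commit, assuming torchrl's spec API as of this version (CompositeSpec, UnboundedContinuousTensorSpec):

    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    obs_spec = CompositeSpec(
        nested_1=CompositeSpec(existing=UnboundedContinuousTensorSpec((3,)))
    )
    primer = CompositeSpec(
        nested_1=CompositeSpec(mykey=UnboundedContinuousTensorSpec((4,)))
    )

    # Overwriting (old behaviour): assignment replaces the whole nested spec,
    # so ("nested_1", "existing") is lost.
    clobbered = obs_spec.clone()
    clobbered["nested_1"] = primer["nested_1"]
    assert ("nested_1", "existing") not in clobbered.keys(True, True)

    # Updating (new behaviour): CompositeSpec.update merges nested entries,
    # keeping the env's original keys alongside the primer's.
    merged = obs_spec.clone()
    merged.update(primer)
    assert ("nested_1", "existing") in merged.keys(True, True)
    assert ("nested_1", "mykey") in merged.keys(True, True)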
test/test_transforms.py (30 additions, 0 deletions)

@@ -6676,6 +6676,36 @@ def test_single_trans_env_check(self):
assert "mykey" in env.reset().keys()
assert ("next", "mykey") in env.rollout(3).keys(True)

def test_nested_key_env(self):
env = MultiKeyCountingEnv()
env_obs_spec_prior_primer = env.observation_spec.clone()
env = TransformedEnv(
env,
TensorDictPrimer(
CompositeSpec(
{
"nested_1": CompositeSpec(
{
"mykey": UnboundedContinuousTensorSpec(
(env.nested_dim_1, 4)
)
},
shape=(env.nested_dim_1,),
)
}
),
reset_key="_reset",
),
)
check_env_specs(env)
env_obs_spec_post_primer = env.observation_spec.clone()
assert ("nested_1", "mykey") in env_obs_spec_post_primer.keys(True, True)
del env_obs_spec_post_primer[("nested_1", "mykey")]
assert env_obs_spec_post_primer == env_obs_spec_prior_primer

assert ("nested_1", "mykey") in env.reset().keys(True, True)
assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True)

def test_transform_no_env(self):
t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))
td = TensorDict({"a": torch.zeros(())}, [])
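The assertions above use tuple-style nested keys; keys(True, True) stands for keys(include_nested=True, leaves_only=True). A quick illustration (not part of the commit) with a plain TensorDict:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"nested_1": TensorDict({"mykey": torch.zeros(3, 4)}, [3])}, []
    )
    print(list(td.keys()))            # ['nested_1']
    print(list(td.keys(True, True)))  # [('nested_1', 'mykey')]
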
torchrl/envs/transforms/transforms.py (12 additions, 14 deletions)

@@ -4646,7 +4646,7 @@ def __init__(
         self.reset_key = reset_key
 
         # sanity check
-        for spec in self.primers.values():
+        for spec in self.primers.values(True, True):
             if not isinstance(spec, TensorSpec):
                 raise ValueError(
                     "The values of the primers must be a subtype of the TensorSpec class. "
@@ -4705,15 +4705,16 @@ def transform_observation_spec(
             raise ValueError(
                 f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
             )
-        for key, spec in self.primers.items():
-            if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
-                expanded_spec = self._expand_shape(spec)
-                spec = expanded_spec
+
+        if self.primers.shape != observation_spec.shape:
             try:
-                device = observation_spec.device
-            except RuntimeError:
-                device = self.device
-            observation_spec[key] = self.primers[key] = spec.to(device)
+                # We try to set the primer shape to the observation spec shape
+                self.primers.shape = observation_spec.shape
+            except ValueError:
+                # If we fail, we expand them to that shape
+                self.primers = self._expand_shape(self.primers)
+        device = observation_spec.device
+        observation_spec.update(self.primers.clone().to(device))
         return observation_spec
 
     def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
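
The new logic reconciles shapes once, at the level of the whole primer composite: it first tries to adopt the parent's batch shape and only expands on failure, then merges into the observation spec. A standalone sketch of that intent (hypothetical values; _expand_shape approximated here with expand):

    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    primers = CompositeSpec(mykey=UnboundedContinuousTensorSpec((4,)))
    obs_batch_shape = (2,)  # pretend batch shape of the parent observation_spec

    if primers.shape != obs_batch_shape:
        try:
            # Succeeds when every leaf spec already starts with these dims.
            primers.shape = obs_batch_shape
        except ValueError:
            # Otherwise prepend the batch dims: mykey goes (4,) -> (2, 4).
            primers = primers.expand(*obs_batch_shape)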
@@ -4763,8 +4764,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def _step(
         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
     ) -> TensorDictBase:
-        for key in self.primers.keys():
-            if key not in next_tensordict.keys(True):
+        for key in self.primers.keys(True, True):
+            if key not in next_tensordict.keys(True, True):
                 prev_val = tensordict.get(key)
                 next_tensordict.set(key, prev_val)
         return next_tensordict
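
_step copies any primer entry the env did not produce at this step from the previous state into the step output; with nested primers, both the iteration and the membership test must see tuple keys, hence keys(True, True) on both sides. A toy sketch of that carry-over (hypothetical data, not the transform itself):

    import torch
    from tensordict import TensorDict

    prev = TensorDict({"nested_1": TensorDict({"mykey": torch.ones(3, 4)}, [3])}, [])
    nxt = TensorDict({"obs": torch.zeros(2)}, [])

    for key in [("nested_1", "mykey")]:  # stand-in for self.primers.keys(True, True)
        if key not in nxt.keys(True, True):
            nxt.set(key, prev.get(key))  # set() builds "nested_1" along the way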
@@ -4782,9 +4783,6 @@ def _reset(
         _reset = _get_reset(self.reset_key, tensordict)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
-                if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size:
-                    expanded_spec = self._expand_shape(spec)
-                    self.primers[key] = spec = expanded_spec
                 if self.random:
                     shape = (
                         ()
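The _reset hunk drops the per-key re-expansion because, after this change, the primer specs have already been reconciled with the parent batch shape in transform_observation_spec. What remains at reset time is filling each primed entry, conceptually like the sketch below (simplified; the real transform also handles default_value and partial resets):

    import torch
    from tensordict import TensorDict
    from torchrl.data import UnboundedContinuousTensorSpec

    spec = UnboundedContinuousTensorSpec((3, 4))
    random = False  # mirrors TensorDictPrimer(random=...)

    td = TensorDict({}, [])
    value = spec.rand() if random else spec.zero()  # sample or zero-fill
    td.set(("nested_1", "mykey"), value)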
