Skip to content

Commit

Permalink
[BugFix] Fixes to RenameTransform (#2442)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
  • Loading branch information
thomasbbrunner and vmoens authored Sep 30, 2024
1 parent b4d543e commit a0dfddc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
22 changes: 22 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9331,6 +9331,28 @@ def test_transform_inverse(self, create_copy):
else:
assert "b" not in tensordict.keys()

def test_rename_action(self, create_copy):
base_env = ContinuousActionVecMockEnv()
env = base_env.append_transform(
RenameTransform(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=[("renamed", "action")],
create_copy=create_copy,
)
)
r = env.rollout(3)
assert ("renamed", "action") in env.action_keys, env.action_keys
assert ("renamed", "action") in r
assert env.full_action_spec[("renamed", "action")] is not None
if create_copy:
assert "action" in env.action_keys
assert "action" in r
else:
assert "action" not in env.action_keys
assert "action" not in r


class TestInitTracker(TransformBase):
@pytest.mark.skipif(not _has_gym, reason="no gym detected")
Expand Down
26 changes: 13 additions & 13 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6634,15 +6634,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:


class RenameTransform(Transform):
"""A transform to rename entries in the output tensordict.
"""A transform to rename entries in the output tensordict (or input tensordict via the inverse keys).
Args:
in_keys (sequence of NestedKey): the entries to rename
in_keys (sequence of NestedKey): the entries to rename.
out_keys (sequence of NestedKey): the name of the entries after renaming.
in_keys_inv (sequence of NestedKey, optional): the entries to rename before
passing the input tensordict to :meth:`EnvBase._step`.
out_keys_inv (sequence of NestedKey, optional): the names of the renamed
entries passed to :meth:`EnvBase._step`.
in_keys_inv (sequence of NestedKey, optional): the entries to rename
in the input tensordict, which will be passed to :meth:`EnvBase._step`.
out_keys_inv (sequence of NestedKey, optional): the names of the entries
in the input tensordict after renaming.
create_copy (bool, optional): if ``True``, the entries will be copied
with a different name rather than being renamed. This allows for
renaming immutable entries such as ``"reward"`` and ``"done"``.
Expand Down Expand Up @@ -6713,7 +6713,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
for in_key, out_key in zip(self.in_keys, self.out_keys):
try:
tensordict.rename_key_(in_key, out_key)
out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
Expand Down Expand Up @@ -6802,9 +6802,9 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

def transform_input_spec(self, input_spec: Composite) -> Composite:
for action_key in self.parent.action_keys:
if action_key in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
if self.in_keys[i] == action_key:
if action_key in self.in_keys_inv:
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
if self.in_keys_inv[i] == action_key:
break
else:
# unreachable
Expand All @@ -6815,9 +6815,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
if not self.create_copy:
del input_spec["full_action_spec"][action_key]
for state_key in self.parent.full_state_spec.keys(True):
if state_key in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
if self.in_keys[i] == state_key:
if state_key in self.in_keys_inv:
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
if self.in_keys_inv[i] == state_key:
break
else:
# unreachable
Expand Down

0 comments on commit a0dfddc

Please sign in to comment.