Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Gym compatibility: Terminal and truncated #1539

Merged
merged 181 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
181 commits
Select commit Hold shift + click to select a range
9e5f303
init
vmoens Sep 15, 2023
03b6089
fix
vmoens Sep 15, 2023
512b596
amend
vmoens Sep 17, 2023
3fa8ac0
amend
vmoens Sep 17, 2023
162aa6e
amend
vmoens Sep 17, 2023
2c5cddc
amend
vmoens Sep 17, 2023
c703a02
amend
vmoens Sep 17, 2023
2b78f49
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 17, 2023
822f42a
amend
vmoens Sep 18, 2023
db8e0c1
amend
vmoens Sep 18, 2023
9c7e1a5
amend
vmoens Sep 18, 2023
99bcc4a
amend
vmoens Sep 18, 2023
9b0069e
lint
vmoens Sep 18, 2023
ac43a7e
fix step counter
vmoens Sep 18, 2023
a822407
amend
vmoens Sep 18, 2023
6cec6e1
amend
vmoens Sep 18, 2023
0612e09
amend
vmoens Sep 18, 2023
02902db
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
7e22c55
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
53927bc
Update torchrl/envs/gym_like.py
vmoens Sep 19, 2023
43ff66a
amend
vmoens Sep 19, 2023
77dcd09
rollout
vmoens Sep 19, 2023
f404245
fix
vmoens Sep 19, 2023
a79b57c
fix
vmoens Sep 19, 2023
149781f
fix
vmoens Sep 19, 2023
71a05b1
fix
vmoens Sep 19, 2023
87bad08
remove prints
vmoens Sep 19, 2023
92c3814
amend
vmoens Sep 19, 2023
42d4c40
amend
vmoens Sep 19, 2023
aec627c
amend
vmoens Sep 19, 2023
b2303bd
amend
vmoens Sep 19, 2023
f6a497b
amend
vmoens Sep 19, 2023
aac630f
amend
vmoens Sep 19, 2023
606ee3a
lint and fixes
vmoens Sep 19, 2023
6ba0d38
amend
vmoens Sep 19, 2023
76b3f0c
amend
vmoens Sep 19, 2023
035c274
amend
vmoens Sep 19, 2023
8bd932f
amend
vmoens Sep 19, 2023
c789e50
amend
vmoens Sep 19, 2023
3e93f13
amend
vmoens Sep 19, 2023
cba97b1
amend
vmoens Sep 19, 2023
7ec7c78
amend
vmoens Sep 19, 2023
dd4c45e
amend
vmoens Sep 19, 2023
0ea0716
fix robohive
vmoens Sep 19, 2023
16d688e
amend
vmoens Sep 20, 2023
268dbd7
amend
vmoens Sep 20, 2023
d77d1cd
amend
vmoens Sep 20, 2023
15bd9fa
amend
vmoens Sep 20, 2023
1b656f7
amend
vmoens Sep 20, 2023
2f13c95
amend
vmoens Sep 20, 2023
aa5de06
amend
vmoens Sep 20, 2023
284262f
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 20, 2023
afcc527
amend
vmoens Sep 20, 2023
9eae41d
amend
vmoens Sep 20, 2023
9210f3b
amend
vmoens Sep 20, 2023
b31b2f0
amend
vmoens Sep 21, 2023
4e8acc0
init
vmoens Sep 21, 2023
c8579f9
init
vmoens Sep 21, 2023
d95989c
Merge branch 'fix_dreamer_tests' into threads_mp
vmoens Sep 21, 2023
e22c318
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
a5e9ce3
prints
vmoens Sep 21, 2023
b6e83d5
amend
vmoens Sep 21, 2023
acf89e6
amend
vmoens Sep 21, 2023
16fba2e
amend
vmoens Sep 21, 2023
697c523
amend
vmoens Sep 21, 2023
369492d
fix
vmoens Sep 21, 2023
c50263c
Merge branch 'main' into terminal_truncated
vmoens Sep 21, 2023
de82499
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e14dcd1
fix
vmoens Sep 21, 2023
4a8424c
amend
vmoens Sep 21, 2023
5056b9e
amend
vmoens Sep 21, 2023
ce26e13
amend
vmoens Sep 21, 2023
5a95850
amend
vmoens Sep 21, 2023
5e38d70
Update torchrl/collectors/collectors.py
vmoens Sep 21, 2023
9eb1c98
amend
vmoens Sep 21, 2023
bccbf67
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
a285780
amend
vmoens Sep 21, 2023
40a8e83
Merge branch 'terminal_truncated' into myo_threaded
vmoens Sep 21, 2023
d8f9505
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 21, 2023
e1eba40
Merge remote-tracking branch 'origin/main' into threads_mp
vmoens Sep 21, 2023
9f58a8d
lint
vmoens Sep 21, 2023
1c4f35f
amend
vmoens Sep 21, 2023
93bd2e6
Merge branch 'threads_mp' into terminal_truncated
vmoens Sep 21, 2023
f6e09e3
amend
vmoens Sep 21, 2023
2cd07c1
amend
vmoens Sep 22, 2023
acf6118
amend
vmoens Sep 22, 2023
bb52ce1
tests
vmoens Sep 22, 2023
0d0bc3c
Merge branch 'main' into terminal_truncated
vmoens Sep 22, 2023
3ef139b
amend
vmoens Sep 22, 2023
0d3ba02
amend
vmoens Sep 22, 2023
0b32209
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 22, 2023
0638e13
amend
vmoens Sep 22, 2023
a8ed5e5
amend
vmoens Sep 22, 2023
54db14e
amend
vmoens Sep 22, 2023
0122fac
amend
vmoens Sep 23, 2023
b18e6e7
amend
vmoens Sep 23, 2023
2f99f16
amend
vmoens Sep 23, 2023
6aa0c9e
amend
vmoens Sep 23, 2023
2b42070
amend
vmoens Sep 23, 2023
7279b12
amend
vmoens Sep 23, 2023
f5ab14d
amend
vmoens Sep 23, 2023
4a9b6b9
amend
vmoens Sep 23, 2023
696324b
amend
vmoens Sep 23, 2023
8890911
Update docs/source/reference/envs.rst
vmoens Sep 24, 2023
e24c2f3
add doc
vmoens Sep 24, 2023
c029f12
amend
vmoens Sep 24, 2023
b37129d
amend
vmoens Sep 24, 2023
9afc783
amend
vmoens Sep 24, 2023
f65622a
amend
vmoens Sep 24, 2023
989eecf
fix VIP
vmoens Sep 24, 2023
77559e0
lint
vmoens Sep 24, 2023
117e41e
osx_skips
vmoens Sep 24, 2023
19fdc33
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 24, 2023
8326674
amend
vmoens Sep 24, 2023
90da3fd
refactor tests
vmoens Sep 25, 2023
f4ddb92
Refactoring: terminated, truncated, done
vmoens Sep 25, 2023
7988118
amend
vmoens Sep 25, 2023
85d5035
let _step return partial done in batched envs
vmoens Sep 25, 2023
b4fa338
fix mocking classes
vmoens Sep 25, 2023
648792d
more fixes
vmoens Sep 25, 2023
9777332
fix step count in equivalence test
vmoens Sep 25, 2023
9063b5b
fix transforms
vmoens Sep 25, 2023
f47e49d
fix transforms
vmoens Sep 25, 2023
c238e9b
fix transformed env
vmoens Sep 25, 2023
502a59b
amend
vmoens Sep 25, 2023
73af141
amend
vmoens Sep 25, 2023
983e246
amend
vmoens Sep 26, 2023
555bca9
amend
vmoens Sep 26, 2023
85f58de
remove calls to done_key
vmoens Sep 26, 2023
6fc39d0
fix step counter
vmoens Sep 26, 2023
175b701
vec envs
vmoens Sep 26, 2023
92e7826
vec envs
vmoens Sep 26, 2023
cd6eaea
amend
vmoens Sep 26, 2023
23e139b
amend
vmoens Sep 26, 2023
677c408
amend
vmoens Sep 26, 2023
3c79081
d4rl
vmoens Sep 26, 2023
54e75b0
d4rl unsqueeze
vmoens Sep 26, 2023
244e3d3
amend
vmoens Sep 26, 2023
09abb71
amend
vmoens Sep 26, 2023
384db24
minor
vmoens Sep 26, 2023
3693cac
amend
vmoens Sep 26, 2023
ca54133
amend
vmoens Sep 26, 2023
0fdc522
amend
vmoens Sep 26, 2023
fdad78f
test_terminated_or_truncated_spec
vmoens Sep 26, 2023
f454e11
more fixes
vmoens Sep 26, 2023
88cee59
--capture no
vmoens Sep 26, 2023
57ccb63
attempt to limit collector idle time
vmoens Sep 26, 2023
e5b0d23
lint
vmoens Sep 26, 2023
dfe726f
amend
vmoens Sep 26, 2023
4f6ce90
amend
vmoens Sep 26, 2023
b16b939
amend
vmoens Sep 26, 2023
daaaddd
amend
vmoens Sep 26, 2023
cd4811f
amend
vmoens Sep 26, 2023
c0f3137
fixes
vmoens Sep 26, 2023
8f9d8fe
fix r3m, vip and vc1
vmoens Sep 27, 2023
4fdf437
fix robohive, d4rl
vmoens Sep 27, 2023
4f44579
amend
vmoens Sep 27, 2023
7787b28
amend
vmoens Sep 27, 2023
5c964a9
amend
vmoens Sep 27, 2023
03001bc
amend
vmoens Sep 27, 2023
04bbee9
lint
vmoens Sep 27, 2023
6ac830e
adapt tests
vmoens Sep 27, 2023
6fbe8c0
amend
vmoens Sep 27, 2023
28fefb6
lint
vmoens Sep 27, 2023
8256e2f
fix gym 0.19
vmoens Sep 27, 2023
1619b09
missing deps
vmoens Sep 27, 2023
23926b2
fix gym truncated
vmoens Sep 27, 2023
e3b8253
fix gym truncated (bis)
vmoens Sep 27, 2023
2a4a1b6
amend
vmoens Sep 27, 2023
7f4c38b
amend
vmoens Sep 27, 2023
2e54626
Merge branch 'main' into terminal_truncated
vmoens Sep 27, 2023
977488e
Merge remote-tracking branch 'origin/main' into terminal_truncated
vmoens Sep 27, 2023
4f932e0
amend
vmoens Sep 27, 2023
5e6a0f0
amend
vmoens Sep 27, 2023
411f097
lint
vmoens Sep 27, 2023
39ed4c3
final (?)
vmoens Sep 28, 2023
21ea856
addressing review
vmoens Sep 28, 2023
f0ee4dd
more fixes
vmoens Sep 28, 2023
7906387
amend
vmoens Sep 28, 2023
72c1240
cloning dones
vmoens Sep 28, 2023
2c7ffb0
amend
vmoens Sep 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
final (?)
  • Loading branch information
vmoens committed Sep 28, 2023
commit 39ed4c3f9827f2adff72aa7f3f2a6fec05edc174
3 changes: 2 additions & 1 deletion .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ do

echo "Testing gym version: ${GYM_VERSION}"
# handling https://github.com/openai/gym/issues/3202
pip install wheel==0.38.4
pip3 install wheel==0.38.4
pip3 install gym==$GYM_VERSION
$DIR/run_test.sh

Expand All @@ -69,6 +69,7 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install wheel==0.38.4
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install ale-py==0.7
$DIR/run_test.sh
Expand Down
21 changes: 8 additions & 13 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,17 @@ delivery and the ``"next"`` entry is gathered by the :func:`~.utils.step_mdp`
function.

.. note::

The Gym(nasium) API recently shifted to a splitting of the ``"done"`` state
into a ``termination`` (the env is done and results should not be trusted)
and ``truncation`` (an external limit on the number of steps is reached) flags.
In TorchRL, ``"done"`` strictly refers to ``termination | truncation``.
In general, all TorchRL environment have a ``"done"`` and ``"terminated"``
entry in their output tensordict. If they are not present by design,
the :class:`~.EnvBase` metaclass will ensure that every done or truncated
is flanked with its dual.
vmoens marked this conversation as resolved.
Show resolved Hide resolved
In TorchRL, ``"done"`` strictly refers to the union of all the end-of-trajectory
signals and should be interpreted as "the last step of a trajectory" or
equivalently "a signal indicating the need to reset".
If the environment provides it (eg, Gymnasium), the truncation entry is also
written in the :meth:`EnvBase.step` output under a ``"truncated"`` entry.
If the environment carries a single value, it will interpreted as a ``"done"``
If the environment carries a single value, it will interpreted as a ``"terminated"``
signal by default.
Some classes in TorchRL may require a ``"terminated"`` signal (eg, value functions).
If none is available, they will fall back on ``"done"`` instead.
The caveat of this choice is that adding a truncation transform (eg, :class:`.StepCounter`)
will override the content of the ``"done"`` signal. If this is a problem
a :class:`~.RenameTransform` should be used to move or copy the ``"done"``
entry (for instance to ``"terminated"``).

By default, TorchRL's collectors and rollout methods will be looking for the ``"done"``
entry to assess if the environment should be reset.

Expand Down
9 changes: 8 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,10 +1534,12 @@ def test_transform_compose(self, max_steps, device, batch, reset_workers):
_reset = torch.randn(done.shape, device=device) < 0
td.set("_reset", _reset)
td.set("done", _reset)
td.set("terminated", _reset)
td.set(("next", "terminated"), done)
td.set(("next", "done"), done)
td.set("step_count", torch.zeros(*batch, 1, dtype=torch.int))
step_counter[0]._step_count_keys = ["step_count"]
step_counter[0]._terminated_keys = ["completed"]
step_counter[0]._terminated_keys = ["terminated"]
step_counter[0]._truncated_keys = ["truncated"]
step_counter[0]._reset_keys = ["_reset"]
step_counter[0]._done_keys = ["done"]
Expand All @@ -1554,6 +1556,7 @@ def test_transform_compose(self, max_steps, device, batch, reset_workers):
)
td = step_mdp(td)
td["next", "done"] = done
td["next", "terminated"] = done
if max_steps is None:
break

Expand Down Expand Up @@ -1592,11 +1595,14 @@ def test_transform_no_env(self, max_steps, device, batch, reset_workers):
while not _reset.any() and reset_workers:
_reset = torch.randn(done.shape, device=device) < 0
td.set("_reset", _reset)
td.set("terminated", _reset)
td.set(("next", "terminated"), done)
td.set("done", _reset)
td.set(("next", "done"), done)
td.set("step_count", torch.zeros(*batch, 1, dtype=torch.int))
step_counter._step_count_keys = ["step_count"]
step_counter._done_keys = ["done"]
step_counter._terminated_keys = ["terminated"]
step_counter._truncated_keys = ["truncated"]
step_counter._reset_keys = ["_reset"]
step_counter._completed_keys = ["completed"]
Expand All @@ -1613,6 +1619,7 @@ def test_transform_no_env(self, max_steps, device, batch, reset_workers):
)
td = step_mdp(td)
td["next", "done"] = done
td["next", "terminated"] = done
if max_steps is None:
break

Expand Down
6 changes: 0 additions & 6 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,12 +883,6 @@ def rollout(self) -> TensorDictBase:
self._tensordict_out.ndim - 1,
out=self._tensordict_out,
)
except KeyError:
print("\n\n err during stack")
print("tensordict list", tensordicts)
print("dest", self._tensordict_out)
print("env", self.env)
raise
return self._tensordict_out

def reset(self, index=None, **kwargs) -> None:
Expand Down
37 changes: 32 additions & 5 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def action_spec(self, value: TensorSpec) -> None:
self.input_spec.lock_()

@property
def full_action_spec(self):
def full_action_spec(self) -> CompositeSpec:
"""The full action spec.

``full_action_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
Expand All @@ -678,6 +678,10 @@ def full_action_spec(self):
"""
return self.input_spec["full_action_spec"]

@full_action_spec.setter
def full_action_spec(self, spec: CompositeSpec) -> None:
self.action_spec = spec

# Reward spec
def _get_reward_keys(self):
keys = self.output_spec["full_reward_spec"].keys(True, True)
Expand Down Expand Up @@ -846,7 +850,7 @@ def reward_spec(self, value: TensorSpec) -> None:
self.output_spec.lock_()

@property
def full_reward_spec(self):
def full_reward_spec(self) -> CompositeSpec:
"""The full reward spec.

``full_reward_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
Expand All @@ -872,6 +876,10 @@ def full_reward_spec(self):
"""
return self.output_spec["full_reward_spec"]

@full_reward_spec.setter
def full_reward_spec(self, spec: CompositeSpec) -> None:
self.reward_spec = spec

# done spec
def _get_done_keys(self):
if "full_done_spec" not in self.output_spec.keys():
Expand Down Expand Up @@ -914,7 +922,7 @@ def done_key(self):
return self.done_keys[0]

@property
def full_done_spec(self):
def full_done_spec(self) -> CompositeSpec:
"""The full done spec.

``full_done_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
Expand Down Expand Up @@ -944,6 +952,10 @@ def full_done_spec(self):
"""
return self.output_spec["full_done_spec"]

@full_done_spec.setter
def full_done_spec(self, spec: CompositeSpec) -> None:
self.done_spec = spec

# Done spec: done specs belong to output_spec
@property
def done_spec(self) -> TensorSpec:
Expand Down Expand Up @@ -1177,7 +1189,13 @@ def observation_spec(self, value: TensorSpec) -> None:
finally:
self.output_spec.lock_()

full_observation_spec = observation_spec
@property
def full_observation_spec(self) -> CompositeSpec:
return self.observation_spec

@full_observation_spec.setter
def full_observation_spec(self, spec: CompositeSpec):
self.observation_spec = spec

# state spec: state specs belong to input_spec
@property
Expand Down Expand Up @@ -1246,7 +1264,7 @@ def state_spec(self, value: CompositeSpec) -> None:
self.input_spec.lock_()

@property
def full_state_spec(self):
def full_state_spec(self) -> CompositeSpec:
"""The full state spec.

``full_state_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
Expand All @@ -1272,6 +1290,10 @@ def full_state_spec(self):
"""
return self.state_spec

@full_state_spec.setter
def full_state_spec(self, spec: CompositeSpec) -> None:
self.state_spec = spec

def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Makes a step in the environment.

Expand Down Expand Up @@ -1770,6 +1792,7 @@ def policy(td):
tensordicts.append(tensordict.clone(False))

if i == max_steps - 1:
# we don't truncated as one could potentially continue the run
break
tensordict = step_mdp(
tensordict,
Expand Down Expand Up @@ -1810,6 +1833,8 @@ def reset_keys(self) -> List[NestedKey]:
settings. They are structured as ``(*prefix, "_reset")`` where ``prefix`` is
a (possibly empty) tuple of strings pointing to a tensordict location
where a done state can be found.

The value of reset_keys is cached.
"""
reset_keys = self.__dict__.get("_reset_keys", None)
if reset_keys is not None:
Expand Down Expand Up @@ -1844,6 +1869,8 @@ def done_keys_groups(self):
inner lists contain the done keys (eg, done and truncated) that can
be read to determine a reset when it is absent.

The value of ``done_keys_groups`` is cached.

"""
done_keys_sorted = self.__dict__.get("_done_keys_groups", None)
if done_keys_sorted is not None:
Expand Down
27 changes: 13 additions & 14 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,23 +135,23 @@ def read_action(self, action):

def read_done(
self,
terminated: bool,
terminated: bool | None = None,
truncated: bool | None = None,
done: bool | None = None,
) -> Tuple[bool | np.ndarray, bool | np.ndarray, bool | np.ndarray, bool]:
"""Done state reader.

In torchrl, a `"done"` signal means that a trajectory has reach its end,
either because it has been interrupted or because it is terminated.
Truncated means the trajectory has been interrupted early.
Terminated means the task is finished.
Truncated means the episode has been interrupted early.
Terminated means the task is finished, the episode is completed.

Args:
terminated (np.ndarray, boolean or other format): completion state obtained
from the environment.
``"terminated"`` equates to ``"termination"`` in gymnasium: the signal that
the environment has reached the end of the game, any data coming
after this should be considered as nonsensical.
terminated (np.ndarray, boolean or other format): completion state
obtained from the environment.
``"terminated"`` equates to ``"termination"`` in gymnasium:
the signal that the environment has reached the end of the
episode, any data coming after this should be considered as nonsensical.
Defaults to ``None``.
truncated (bool or None): early truncation signal.
Defaults to ``None``.
Expand Down Expand Up @@ -315,26 +315,25 @@ def _output_transform(
"""A method to read the output of the env step.

Must return a tuple: (obs, reward, terminated, truncated, done, info).
If only one end-of-trajectory is passed, it is interpreted as ``"done"``
(unspecified end-of-traj).
If only one end-of-trajectory is passed, it is interpreted as ``"truncated"``.
An attempt to retrieve ``"truncated"`` from the info dict is also undertaken.
If 2 are passed (like in gymnasium), we interpret them as ``"terminated",
"truncated"`` (``"truncated"`` meaning that the trajectory has been
interrupted early), and ``"done"`` is the union of the two,
ie. the unspecified end-of-trajectory signal.

These three concepts have different usage:

- ``"terminated"`` means that one should not pay attention to the
- ``"terminated"`` indicated the final stage of a Markov Decision
Process. It means that one should not pay attention to the
upcoming observations (eg., in value functions) as they should be
regarded as not valid.
This is a "game-over" situation, the result of the action is the
end of the game (win or loose).
- ``"truncated"`` means that the environment has reached a stage where
we decided to stop the collection for some reason but the next
observation should not be discarded. If it were not for this
arbitrary decision, the collection could have proceeded further.
- ``"done"`` is either one or the other. It is to be interpreted as
"a reset should be called at the next step".
"a reset should be called before the next step is undertaken".

"""
...
Expand Down
48 changes: 31 additions & 17 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4424,6 +4424,21 @@ def done_keys(self):
self.__dict__["_done_keys"] = done_keys
return done_keys

@property
def terminated_keys(self):
terminated_keys = self.__dict__.get("_terminated_keys", None)
if terminated_keys is None:
# make the default terminated keys
terminated_keys = []
for (terminated_key, *_) in self.parent.done_keys_groups:
if isinstance(terminated_key, str):
key = "terminated"
else:
key = (*terminated_key[:-1], "terminated")
terminated_keys.append(key)
self.__dict__["_terminated_keys"] = terminated_keys
return terminated_keys

@property
def step_count_keys(self):
step_count_keys = self.__dict__.get("_step_count_keys", None)
Expand Down Expand Up @@ -4495,11 +4510,11 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
for step_count_key, truncated_key, done_key, done_list_sorted in zip(
for step_count_key, truncated_key, done_key, terminated_key in zip(
self.step_count_keys,
self.truncated_keys,
self.done_keys,
self.done_keys_groups,
self.terminated_keys,
):
step_count = tensordict.get(step_count_key)
next_step_count = step_count + 1
Expand All @@ -4508,10 +4523,9 @@ def _step(
truncated = next_step_count >= self.max_steps
if self.update_done:
done = next_tensordict.get(done_key, None)
if done is None:
done = False
for done in done_list_sorted:
done = done | next_tensordict.get(done_key, default=False)
terminated = next_tensordict.get(terminated_key, None)
if terminated is not None:
truncated = truncated & ~terminated
done = truncated | done # we assume no done after reset
next_tensordict.set(done_key, done)
next_tensordict.set(truncated_key, truncated)
Expand Down Expand Up @@ -5734,8 +5748,8 @@ def _step(
) -> TensorDictBase:
# save the final info
done = False
# TODO: check if there's a done, and if there is, get it
for done_key in self.done_keys:
# we assume dones can be broadcast
done = done | next_tensordict.get(done_key)
if done is False:
raise RuntimeError(
Expand Down Expand Up @@ -5808,16 +5822,16 @@ def done_keys(self) -> List[NestedKey]:
keys = self.__dict__.get("_done_keys", None)
if keys is None:
keys = self.parent.done_keys
self._done_keys = keys
expected_done_keys = {"done", "truncated", "terminated"}
# put this check for now. We can consider relaxing that later
# and allow nested values, though they will still need to be unique.
for done_key in keys:
if done_key not in expected_done_keys:
raise RuntimeError(
f"VecGymEnvTransform only supports the following "
f"done keys: {expected_done_keys}, but it got {done_key}."
)
# we just want the "done" key
_done_keys = []
for key in keys:
if not isinstance(key, tuple):
key = (key,)
if key[-1] == "done":
_done_keys.append(unravel_key(key))
if not len(_done_keys):
raise RuntimeError("Could not find a 'done' key in the env specs.")
self._done_keys = _done_keys
return keys

@property
Expand Down
Loading