Skip to content

Commit

Permalink
[Feature] optionally set truncated = True at the end of rollouts (pyt…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 27, 2024
1 parent a7bf5a4 commit f439b54
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 8 deletions.
12 changes: 6 additions & 6 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,10 +1215,10 @@ def _reset(self, tensordict):
for done_key in self.done_keys:
if isinstance(done_key, str):
continue
if self.has_root_done:
done = tensordict_reset.get(done_key[-1])
else:
done = tensordict_reset.pop(done_key[-1])
done = tensordict_reset.pop(done_key[-1], None)
if done is None:
continue
tensordict_reset.set(
done_key,
(done.unsqueeze(-2).expand(*self.batch_size, self.nested_dim, 1)),
Expand Down Expand Up @@ -1254,10 +1254,10 @@ def _step(self, tensordict):
for done_key in self.done_keys:
if isinstance(done_key, str):
continue
if self.has_root_done:
done = next_tensordict.get(done_key[-1])
else:
done = next_tensordict.pop(done_key[-1])
done = next_tensordict.pop(done_key[-1], None)
if done is None:
continue
next_tensordict.set(
done_key,
(done.unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)),
Expand Down
31 changes: 31 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,37 @@ def policy(td):
collector.shutdown()


@pytest.mark.parametrize(
"collector_cls",
[SyncDataCollector, MultiSyncDataCollector, MultiaSyncDataCollector],
)
def test_set_truncated(collector_cls):
env_fn = lambda: TransformedEnv(
NestedCountingEnv(), InitTracker()
).add_truncated_keys()
env = env_fn()
policy = env.rand_action
if collector_cls == SyncDataCollector:
collector = collector_cls(
env, policy=policy, frames_per_batch=20, total_frames=-1, set_truncated=True
)
else:
collector = collector_cls(
[env_fn, env_fn],
policy=policy,
frames_per_batch=20,
total_frames=-1,
cat_results="stack",
set_truncated=True,
)
try:
for data in collector:
assert data[..., -1]["next", "data", "truncated"].all()
break
finally:
collector.shutdown()


class TestNestedEnvsCollector:
def test_multi_collector_nested_env_consistency(self, seed=1):
torch.manual_seed(seed)
Expand Down
11 changes: 11 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,17 @@ def test_rollout(env_name, frame_skip, seed=0):
env.close()


def test_rollout_set_truncated():
env = ContinuousActionVecMockEnv()
with pytest.raises(RuntimeError, match="set_truncated was set to True"):
env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
env.add_truncated_keys()
r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
assert r.shape == torch.Size([10])
assert r[..., -1]["next", "truncated"].all()
assert r[..., -1]["next", "done"].all()


@pytest.mark.parametrize("max_steps", [1, 5])
def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
# CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
Expand Down
6 changes: 6 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,12 @@ def format_size(size):
logger.info(indent + os.path.basename(path))


def _ends_with(key, match):
if isinstance(key, str):
return key == match
return key[-1] == match


def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
if isinstance(key, str):
return new_ending
Expand Down
54 changes: 53 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@

from torchrl._utils import (
_check_for_faulty_process,
_ends_with,
_ProcessNoWarn,
_replace_last,
accept_remote_rref_udf_invocation,
logger as torchrl_logger,
prod,
Expand Down Expand Up @@ -346,6 +348,11 @@ class SyncDataCollector(DataCollectorBase):
The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement
strategies such as preeptively stopping rollout collection.
Default is ``False``.
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand Down Expand Up @@ -433,6 +440,7 @@ def __init__(
return_same_td: bool = False,
reset_when_done: bool = True,
interruptor=None,
set_truncated: bool = False,
):
from torchrl.envs.batched_envs import BatchedEnvBase

Expand Down Expand Up @@ -770,6 +778,23 @@ def filter_policy(value_output, value_input, value_input_clone):
self.interruptor = interruptor
self._frames = 0
self._iter = -1
self.set_truncated = set_truncated
self._truncated_keys = []
if self.set_truncated:
if not any(
_ends_with(key, "truncated")
for key in self._final_rollout.keys(True, True)
):
raise RuntimeError(
"set_truncated was set to True but no truncated key could be found "
"in the environment. Make sure the truncated keys are properly set using "
"`env.add_truncated_keys()` before passing the env to the collector."
)
self._truncated_keys = [
key
for key in self._final_rollout["next"].keys(True, True)
if _ends_with(key, "truncated")
]

@classmethod
def _get_devices(
Expand Down Expand Up @@ -1038,7 +1063,16 @@ def rollout(self) -> TensorDictBase:
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
return self._final_rollout
return self._maybe_set_truncated(self._final_rollout)

def _maybe_set_truncated(self, final_rollout):
last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,)
for truncated_key in self._truncated_keys:
truncated = final_rollout["next", truncated_key]
truncated[last_step] = True
final_rollout["next", truncated_key] = truncated
final_rollout["next", _replace_last(truncated_key, "done")] = truncated
return final_rollout

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
Expand Down Expand Up @@ -1283,6 +1317,12 @@ class _MultiDataCollector(DataCollectorBase):
.. note:: From v0.5, this argument will default to ``"stack"`` for a better
interoperability with the rest of the library.
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
"""

def __init__(
Expand Down Expand Up @@ -1315,13 +1355,15 @@ def __init__(
num_threads: int = None,
num_sub_threads: int = 1,
cat_results: str | int | None = None,
set_truncated: bool = False,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
self.closed = True
self.num_workers = len(create_env_fn)

self.set_truncated = set_truncated
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self.create_env_fn = create_env_fn
Expand Down Expand Up @@ -1618,6 +1660,7 @@ def _run_processes(self) -> None:
"reset_when_done": self.reset_when_done,
"idx": i,
"interruptor": self.interruptor,
"set_truncated": self.set_truncated,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
Expand Down Expand Up @@ -2493,6 +2536,11 @@ class aSyncDataCollector(MultiaSyncDataCollector):
each subprocess (or one if a single process is launched).
Defaults to 1 for safety: if none is indicated, launching multiple
workers may charge the cpu load too much and harm performance.
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
"""

Expand Down Expand Up @@ -2525,6 +2573,7 @@ def __init__(
preemptive_threshold: float = None,
num_threads: int = None,
num_sub_threads: int = 1,
set_truncated: bool = False,
**kwargs,
):
super().__init__(
Expand All @@ -2549,6 +2598,7 @@ def __init__(
preemptive_threshold=preemptive_threshold,
num_threads=num_threads,
num_sub_threads=num_sub_threads,
set_truncated=set_truncated,
)

# for RPC
Expand Down Expand Up @@ -2590,6 +2640,7 @@ def _main_async_collector(
reset_when_done: bool = True,
verbose: bool = VERBOSE,
interruptor=None,
set_truncated: bool = False,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
Expand All @@ -2612,6 +2663,7 @@ def _main_async_collector(
reset_when_done=reset_when_done,
return_same_td=True,
interruptor=interruptor,
set_truncated=set_truncated,
)
if verbose:
torchrl_logger.info("Sync data collector created")
Expand Down
39 changes: 38 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
from tensordict.base import NO_DEFAULT
from tensordict.utils import NestedKey
from torchrl._utils import _replace_last, implement_for, prod, seed_generator
from torchrl._utils import (
_ends_with,
_replace_last,
implement_for,
prod,
seed_generator,
)

from torchrl.data.tensor_specs import (
CompositeSpec,
Expand Down Expand Up @@ -2280,6 +2286,7 @@ def rollout(
break_when_any_done: bool = True,
return_contiguous: bool = True,
tensordict: Optional[TensorDictBase] = None,
set_truncated: bool = False,
out=None,
):
"""Executes a rollout in the environment.
Expand Down Expand Up @@ -2308,6 +2315,11 @@ def rollout(
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the
output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout.
set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to
``True`` after completion of the rollout. If no ``"truncated"`` is found within the
``done_spec``, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
Returns:
TensorDict object containing the resulting trajectory.
Expand Down Expand Up @@ -2539,9 +2551,34 @@ def rollout(
out_td = LazyStackedTensorDict.lazy_stack(
tensordicts, len(batch_size), out=out
)
if set_truncated:
found_truncated = False
for key in self.done_keys:
if _ends_with(key, "truncated"):
val = out_td.get(("next", key))
val[(slice(None),) * (out_td.ndim - 1) + (-1,)] = True
out_td.set(("next", key), val)
out_td.set(("next", _replace_last(key, "done")), val)
found_truncated = True
if not found_truncated:
raise RuntimeError(
"set_truncated was set to True but no truncated key could be found. "
"Make sure a 'truncated' entry was set in the environment "
"full_done_keys using `env.add_truncated_keys()`."
)

out_td.refine_names(..., "time")
return out_td

def add_truncated_keys(self) -> EnvBase:
"""Adds truncated keys to the environment."""
for key in self.done_keys:
self.full_done_spec[_replace_last(key, "truncated")] = self.full_done_spec[
key
]
self.__dict__["_done_keys"] = None
return self

@property
def _step_mdp(self):
step_func = self.__dict__.get("_step_mdp_value", None)
Expand Down
5 changes: 5 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,11 @@ def batch_size(self, value: torch.Size) -> None:
"Cannot modify the batch-size of a transformed env. Change the batch size of the base_env instead."
)

def add_truncated_keys(self) -> TransformedEnv:
self.base_env.add_truncated_keys()
self.empty_cache()
return self

def _set_env(self, env: EnvBase, device) -> None:
if device != env.device:
env = env.to(device)
Expand Down

0 comments on commit f439b54

Please sign in to comment.