From 8fb55ef99e1d12a83d05bd1a73ddac439d104e9b Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Mon, 29 Jul 2024 21:09:17 +0200 Subject: [PATCH] [Docs] InitTracker cleanup (#2330) --- torchrl/envs/transforms/transforms.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 70aef03e041..afc2313abdd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6420,6 +6420,7 @@ class InitTracker(Transform): Args: init_key (NestedKey, optional): the key to be used for the tracker entry. + In case of multiple _reset flags, this key is used as the leaf replacement for each. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -6433,11 +6434,12 @@ class InitTracker(Transform): """ - def __init__(self, init_key: NestedKey = "is_init"): + def __init__(self, init_key: str = "is_init"): if not isinstance(init_key, str): - raise ValueError("init_key can only be of type str.") + raise ValueError( + "init_key can only be of type str as it will be the leaf key associated to each reset flag." + ) self.init_key = init_key - self.reset_key = "_reset" super().__init__() def set_container(self, container: Union[Transform, EnvBase]) -> None: