diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 70aef03e041..afc2313abdd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6420,6 +6420,7 @@ class InitTracker(Transform): Args: init_key (NestedKey, optional): the key to be used for the tracker entry. + In case of multiple _reset flags, this key is used as the leaf replacement for each. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -6433,11 +6434,12 @@ class InitTracker(Transform): """ - def __init__(self, init_key: NestedKey = "is_init"): + def __init__(self, init_key: str = "is_init"): if not isinstance(init_key, str): - raise ValueError("init_key can only be of type str.") + raise ValueError( + "init_key can only be of type str as it will be the leaf key associated to each reset flag." + ) self.init_key = init_key - self.reset_key = "_reset" super().__init__() def set_container(self, container: Union[Transform, EnvBase]) -> None: