diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py
index 6ef07b6c93a..7157e174092 100644
--- a/torchrl/modules/tensordict_module/sequence.py
+++ b/torchrl/modules/tensordict_module/sequence.py
@@ -153,7 +153,9 @@ def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[Lis
             # necessary to run a TensorDictModule. If a key is an intermediary in
             # the chain, there is no reason why it should belong to the input
             # TensorDict.
-            in_keys += [key for key in module.in_keys if key not in out_keys + in_keys]
+            for in_key in module.in_keys:
+                if in_key not in (out_keys + in_keys):
+                    in_keys.append(in_key)
             out_keys += module.out_keys
 
         out_keys = [
@@ -271,12 +273,20 @@ def _run_module(
         if not self.partial_tolerant or all(
             key in tensordict_keys for key in module.in_keys
         ):
-            tensordict = module(tensordict, params=params, buffers=buffers, **kwargs)
+            if params is not None or buffers is not None:
+                tensordict = module(
+                    tensordict, params=params, buffers=buffers, **kwargs
+                )
+            else:
+                tensordict = module(tensordict, **kwargs)
         elif self.partial_tolerant and isinstance(tensordict, LazyStackedTensorDict):
             for sub_td in tensordict.tensordicts:
                 tensordict_keys = set(sub_td.keys())
                 if all(key in tensordict_keys for key in module.in_keys):
-                    module(sub_td, params=params, buffers=buffers, **kwargs)
+                    if params is not None or buffers is not None:
+                        module(sub_td, params=params, buffers=buffers, **kwargs)
+                    else:
+                        module(sub_td, **kwargs)
             tensordict._update_valid_keys()
         return tensordict
 
@@ -348,7 +358,10 @@ def __len__(self):
         return len(self.module)
 
     def __getitem__(self, index: Union[int, slice]) -> TensorDictModule:
-        return self.module.__getitem__(index)
+        if isinstance(index, int):
+            return self.module.__getitem__(index)
+        else:
+            return TensorDictSequential(*self.module.__getitem__(index))
 
     def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None:
        return self.module.__setitem__(idx=index, module=tensordict_module)
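
A minimal usage sketch of the new behaviour (not part of the patch; the import paths and constructor calls below are assumptions for the torchrl version this patch targets): integer indexing still returns the underlying TensorDictModule, slicing now wraps the result back into a TensorDictSequential instead of returning a bare nn.ModuleList slice, and calling a module without functional params/buffers goes through the new `module(tensordict, **kwargs)` branch.

import torch
from torch import nn
from torchrl.data import TensorDict  # assumed import path for this torchrl version
from torchrl.modules import TensorDictModule, TensorDictSequential

seq = TensorDictSequential(
    TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"]),
    TensorDictModule(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden_act"]),
    TensorDictModule(nn.Linear(4, 1), in_keys=["hidden_act"], out_keys=["value"]),
)

first = seq[0]      # int index -> a TensorDictModule, unchanged behaviour
sub_seq = seq[:2]   # slice -> a TensorDictSequential (per this patch)
assert isinstance(sub_seq, TensorDictSequential)

# The sliced sub-sequence stays callable on a tensordict. No params/buffers are
# passed here, so each sub-module is invoked through the new
# `module(tensordict, **kwargs)` fallback rather than with params=None/buffers=None.
td = TensorDict({"obs": torch.randn(2, 3)}, batch_size=[2])
sub_seq(td)
print(td.get("hidden_act").shape)  # torch.Size([2, 4])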