Skip to content

Commit

Permalink
[BugFix] Minor bugfixes in TensorDictSequential (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 1, 2022
1 parent c644a34 commit 73028e4
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions torchrl/modules/tensordict_module/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[Lis
# necessary to run a TensorDictModule. If a key is an intermediary in
# the chain, there is no reason why it should belong to the input
# TensorDict.
in_keys += [key for key in module.in_keys if key not in out_keys + in_keys]
for in_key in module.in_keys:
if in_key not in (out_keys + in_keys):
in_keys.append(in_key)
out_keys += module.out_keys

out_keys = [
Expand Down Expand Up @@ -271,12 +273,20 @@ def _run_module(
if not self.partial_tolerant or all(
key in tensordict_keys for key in module.in_keys
):
tensordict = module(tensordict, params=params, buffers=buffers, **kwargs)
if params is not None or buffers is not None:
tensordict = module(
tensordict, params=params, buffers=buffers, **kwargs
)
else:
tensordict = module(tensordict, **kwargs)
elif self.partial_tolerant and isinstance(tensordict, LazyStackedTensorDict):
for sub_td in tensordict.tensordicts:
tensordict_keys = set(sub_td.keys())
if all(key in tensordict_keys for key in module.in_keys):
module(sub_td, params=params, buffers=buffers, **kwargs)
if params is not None or buffers is not None:
module(sub_td, params=params, buffers=buffers, **kwargs)
else:
module(sub_td, **kwargs)
tensordict._update_valid_keys()
return tensordict

Expand Down Expand Up @@ -348,7 +358,10 @@ def __len__(self):
return len(self.module)

def __getitem__(self, index: Union[int, slice]) -> TensorDictModule:
return self.module.__getitem__(index)
if isinstance(index, int):
return self.module.__getitem__(index)
else:
return TensorDictSequential(*self.module.__getitem__(index))

def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None:
return self.module.__setitem__(idx=index, module=tensordict_module)
Expand Down

0 comments on commit 73028e4

Please sign in to comment.