Skip to content

Commit

Permalink
[BugFix] Fix get-related errors (#2361)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 5, 2024
1 parent 35a1c5b commit e76d8cb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4121,7 +4121,7 @@ def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
for key, item in self._specs.items():
if item is None or (isinstance(item, CompositeSpec) and item.is_empty()):
continue
val_item = val.get(key)
val_item = val[key]
if not item.is_in(val_item):
return False
return True
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5601,7 +5601,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
time_dim = time_dim[0] - 1
for in_key, out_key in zip(self.in_keys, self.out_keys):
reward = tensordict.get(in_key)
reward = tensordict[in_key]
cumsum = reward.cumsum(time_dim)
tensordict.set(out_key, cumsum)
return tensordict
Expand Down
8 changes: 4 additions & 4 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def forward(self, tensordict: TensorDictBase):
else:
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

is_init = tensordict_shaped.get("is_init").squeeze(-1)
is_init = tensordict_shaped["is_init"].squeeze(-1)
splits = None
if self.recurrent_mode and is_init[..., 1:].any():
# if we have consecutive trajectories, things get a little more complicated
Expand All @@ -679,7 +679,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped = _split_and_pad_sequence(
tensordict_shaped.select(*self.in_keys, strict=False), splits
)
is_init = tensordict_shaped.get("is_init").squeeze(-1)
is_init = tensordict_shaped["is_init"].squeeze(-1)

value, hidden0, hidden1 = (
tensordict_shaped.get(key, default)
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def forward(self, tensordict: TensorDictBase):
else:
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

is_init = tensordict_shaped.get("is_init").squeeze(-1)
is_init = tensordict_shaped["is_init"].squeeze(-1)
splits = None
if self.recurrent_mode and is_init[..., 1:].any():
# if we have consecutive trajectories, things get a little more complicated
Expand All @@ -1424,7 +1424,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped = _split_and_pad_sequence(
tensordict_shaped.select(*self.in_keys, strict=False), splits
)
is_init = tensordict_shaped.get("is_init").squeeze(-1)
is_init = tensordict_shaped["is_init"].squeeze(-1)

value, hidden = (
tensordict_shaped.get(key, default)
Expand Down

0 comments on commit e76d8cb

Please sign in to comment.