From e76d8cbfc3bb000c159b53b373d9ee33b9199bd7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Aug 2024 13:02:06 +0100 Subject: [PATCH] [BugFix] Fix get-related errors (#2361) --- torchrl/data/tensor_specs.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- torchrl/modules/tensordict_module/rnn.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7c787b3ccfc..7f94ae80aeb 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4121,7 +4121,7 @@ def is_in(self, val: Union[dict, TensorDictBase]) -> bool: for key, item in self._specs.items(): if item is None or (isinstance(item, CompositeSpec) and item.is_empty()): continue - val_item = val.get(key) + val_item = val[key] if not item.is_in(val_item): return False return True diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7c9dec980f5..255af86a61e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5601,7 +5601,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) time_dim = time_dim[0] - 1 for in_key, out_key in zip(self.in_keys, self.out_keys): - reward = tensordict.get(in_key) + reward = tensordict[in_key] cumsum = reward.cumsum(time_dim) tensordict.set(out_key, cumsum) return tensordict diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 878fb13ebb8..048ddedbf9d 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -665,7 +665,7 @@ def forward(self, tensordict: TensorDictBase): else: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): # if we have consecutive trajectories, things get a little more complicated @@ -679,7 +679,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped = _split_and_pad_sequence( tensordict_shaped.select(*self.in_keys, strict=False), splits ) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) value, hidden0, hidden1 = ( tensordict_shaped.get(key, default) @@ -1410,7 +1410,7 @@ def forward(self, tensordict: TensorDictBase): else: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): # if we have consecutive trajectories, things get a little more complicated @@ -1424,7 +1424,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped = _split_and_pad_sequence( tensordict_shaped.select(*self.in_keys, strict=False), splits ) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) value, hidden = ( tensordict_shaped.get(key, default)