Commit

[Refactor] Refactor calls to get without default that raise KeyError (p…
vmoens authored Aug 4, 2024
1 parent 829a9a2 commit 35a1c5b
Showing 6 changed files with 67 additions and 63 deletions.
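
The commit applies one pattern throughout: calls to get(key) with no default, which relied on a missing key raising KeyError, are replaced either by plain indexing (which raises KeyError natively) or by get(key, None) followed by an explicit None check that raises a descriptive KeyError. A minimal before/after sketch of the two variants, for orientation only; td and key are placeholder names, not taken from the diff:

# Before: get() without a default, expected to raise KeyError for a missing key.
try:
    value = td.get(key)
except KeyError:
    raise KeyError(f"{key} was not found in the input tensordict")

# After, variant 1 (samplers.py, qmixer.py): plain indexing raises KeyError natively.
value = td[key]

# After, variant 2 (the loss modules): get() with an explicit None default,
# followed by a check that raises a descriptive KeyError.
value = td.get(key, None)  # TODO: None soon to be removed
if value is None:
    raise KeyError(f"{key} was not found in the input tensordict")
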
10 changes: 3 additions & 7 deletions torchrl/data/replay_buffers/samplers.py
@@ -1085,12 +1085,10 @@ def _get_stop_and_length(self, storage, fallback=True):
             # We first try with the traj_key
             try:
                 if isinstance(storage, TensorStorage):
-                    trajectory = storage[:].get(self._used_traj_key)
+                    trajectory = storage[:][self._used_traj_key]
                 else:
                     try:
-                        trajectory = storage[:].get(self.traj_key)
-                    except KeyError:
-                        raise
+                        trajectory = storage[:][self.traj_key]
                     except Exception:
                         raise RuntimeError(
                             "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
@@ -1112,9 +1110,7 @@ def _get_stop_and_length(self, storage, fallback=True):
         else:
             try:
                 try:
-                    done = storage[:].get(self.end_key)
-                except KeyError:
-                    raise
+                    done = storage[:][self.end_key]
                 except Exception:
                     raise RuntimeError(
                         "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
5 changes: 2 additions & 3 deletions torchrl/envs/transforms/gym_transforms.py
@@ -150,9 +150,8 @@ def _step(self, tensordict, next_tensordict):
         end_of_life = torch.as_tensor(
             tensordict.get(self.lives_key) > lives, device=self.parent.device
         )
-        try:
-            done = next_tensordict.get(self.done_key)
-        except KeyError:
+        done = next_tensordict.get(self.done_key, None)  # TODO: None soon to be removed
+        if done is None:
             raise KeyError(
                 f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {tensordict.keys(True, True)}. "
                 f"Make sure to pass the appropriate done_key to the {type(self)} transform."
49 changes: 26 additions & 23 deletions torchrl/objectives/a2c.py
@@ -408,41 +408,44 @@ def _log_probs(
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value).clone()
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                     f"Make sure that the value_key passed to A2C exists in the input tensordict."
                 )
+            old_state_value = old_state_value.clone()
 
-        try:
-            # TODO: if the advantage is gathered by forward, this introduces an
-            # overhead that we could easily reduce.
-            target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(
-                *self.critic_network.in_keys, strict=False
-            )
-            with self.critic_network_params.to_module(
-                self.critic_network
-            ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic_network(
-                    tensordict_select,
-                ).get(self.tensor_keys.value)
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
-        except KeyError:
+        # TODO: if the advantage is gathered by forward, this introduces an
+        # overhead that we could easily reduce.
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
                 f"return) has been retrieved accordingly. Advantage classes such as GAE, "
                 f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
                 f"can be used for the value loss."
             )
+        tensordict_select = tensordict.select(
+            *self.critic_network.in_keys, strict=False
+        )
+        with self.critic_network_params.to_module(
+            self.critic_network
+        ) if self.functional else contextlib.nullcontext():
+            state_value = self.critic_network(
+                tensordict_select,
+            ).get(self.tensor_keys.value)
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
         clip_fraction = None
         if self.clip_value:
             loss_value, clip_fraction = _clip_value_loss(
2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
@@ -355,7 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         td_copy.set(self.tensor_keys.local_value, pred_val_index)  # [*B, n_agents, 1]
         with self.mixer_network_params.to_module(self.mixer_network):
             self.mixer_network(td_copy)
-        pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1)
+        pred_val_index = td_copy[self.tensor_keys.global_value].squeeze(-1)
         # [*B] this is global and shared among the agents as will be the target
 
         target_value = self.value_estimator.value_estimate(
21 changes: 12 additions & 9 deletions torchrl/objectives/ppo.py
@@ -482,9 +482,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         # overhead that we could easily reduce.
         if self.separate_losses:
             tensordict = tensordict.detach()
-        try:
-            target_return = tensordict.get(self.tensor_keys.value_target)
-        except KeyError:
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
@@ -494,9 +495,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             )
 
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value)
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
@@ -508,9 +510,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         ) if self.functional else contextlib.nullcontext():
             state_value_td = self.critic_network(tensordict)
 
-        try:
-            state_value = state_value_td.get(self.tensor_keys.value)
-        except KeyError:
+        state_value = state_value_td.get(
+            self.tensor_keys.value, None
+        )  # TODO: None soon to be removed
+        if state_value is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value} was not found in the critic output tensordict. "
                 f"Make sure that the value_key passed to PPO is accurate."
43 changes: 23 additions & 20 deletions torchrl/objectives/reinforce.py
@@ -398,32 +398,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
 
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value).clone()
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                     f"Make sure that the value_key passed to Reinforce exists in the input tensordict."
                 )
+            old_state_value = old_state_value.clone()
 
-        try:
-            target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(
-                *self.critic_network.in_keys, strict=False
-            )
-            with self.critic_network_params.to_module(
-                self.critic_network
-            ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic_network(tensordict_select).get(
-                    self.tensor_keys.value
-                )
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
-        except KeyError:
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
@@ -432,6 +421,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"can be used for the value loss."
             )
 
+        tensordict_select = tensordict.select(
+            *self.critic_network.in_keys, strict=False
+        )
+        with self.critic_network_params.to_module(
+            self.critic_network
+        ) if self.functional else contextlib.nullcontext():
+            state_value = self.critic_network(tensordict_select).get(
+                self.tensor_keys.value
+            )
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
         clip_fraction = None
         if self.clip_value:
             loss_value, clip_fraction = _clip_value_loss(
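As a usage note, here is a small sketch of how the two access styles used above behave on a missing entry. It relies only on the public tensordict API, and the entry names are made up rather than taken from this commit:

import torch
from tensordict import TensorDict

td = TensorDict({"observation": torch.zeros(3)}, batch_size=[3])

# Plain indexing raises KeyError natively, which the samplers.py and
# qmixer.py changes rely on.
try:
    td["state_value"]
except KeyError:
    print("indexing a missing entry raised KeyError")

# get() with an explicit default returns the default instead of raising,
# which is why the loss modules now check for None and raise a
# descriptive KeyError themselves.
assert td.get("state_value", None) is None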
