Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor calls to get without default that raise KeyError #2353

Merged: 2 commits were merged on Aug 4, 2024.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,12 +1085,10 @@ def _get_stop_and_length(self, storage, fallback=True):
# We first try with the traj_key
try:
if isinstance(storage, TensorStorage):
trajectory = storage[:].get(self._used_traj_key)
trajectory = storage[:][self._used_traj_key]
else:
try:
trajectory = storage[:].get(self.traj_key)
except KeyError:
raise
trajectory = storage[:][self.traj_key]
except Exception:
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
Expand All @@ -1112,9 +1110,7 @@ def _get_stop_and_length(self, storage, fallback=True):
else:
try:
try:
done = storage[:].get(self.end_key)
except KeyError:
raise
done = storage[:][self.end_key]
except Exception:
raise RuntimeError(
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
Expand Down
5 changes: 2 additions & 3 deletions torchrl/envs/transforms/gym_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ def _step(self, tensordict, next_tensordict):
end_of_life = torch.as_tensor(
tensordict.get(self.lives_key) > lives, device=self.parent.device
)
try:
done = next_tensordict.get(self.done_key)
except KeyError:
done = next_tensordict.get(self.done_key, None) # TODO: None soon to be removed
if done is None:
raise KeyError(
f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {tensordict.keys(True, True)}. "
f"Make sure to pass the appropriate done_key to the {type(self)} transform."
Expand Down
49 changes: 26 additions & 23 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,41 +408,44 @@ def _log_probs(

def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
if self.clip_value:
try:
old_state_value = tensordict.get(self.tensor_keys.value).clone()
except KeyError:
old_state_value = tensordict.get(
self.tensor_keys.value, None
) # TODO: None soon to be removed
if old_state_value is None:
raise KeyError(
f"clip_value is set to {self.clip_value}, but "
f"the key {self.tensor_keys.value} was not found in the input tensordict. "
f"Make sure that the value_key passed to A2C exists in the input tensordict."
)

try:
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(
*self.critic_network.in_keys, strict=False
)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():
state_value = self.critic_network(
tensordict_select,
).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
except KeyError:
old_state_value = old_state_value.clone()

# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
target_return = tensordict.get(
self.tensor_keys.value_target, None
) # TODO: None soon to be removed
if target_return is None:
raise KeyError(
f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
f"return) has been retrieved accordingly. Advantage classes such as GAE, "
f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
f"can be used for the value loss."
)
tensordict_select = tensordict.select(
*self.critic_network.in_keys, strict=False
)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():
state_value = self.critic_network(
tensordict_select,
).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
clip_fraction = None
if self.clip_value:
loss_value, clip_fraction = _clip_value_loss(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1]
with self.mixer_network_params.to_module(self.mixer_network):
self.mixer_network(td_copy)
pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1)
pred_val_index = td_copy[self.tensor_keys.global_value].squeeze(-1)
# [*B] this is global and shared among the agents as will be the target

target_value = self.value_estimator.value_estimate(
Expand Down
21 changes: 12 additions & 9 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
# overhead that we could easily reduce.
if self.separate_losses:
tensordict = tensordict.detach()
try:
target_return = tensordict.get(self.tensor_keys.value_target)
except KeyError:
target_return = tensordict.get(
self.tensor_keys.value_target, None
) # TODO: None soon to be removed
if target_return is None:
raise KeyError(
f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
Expand All @@ -494,9 +495,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
)

if self.clip_value:
try:
old_state_value = tensordict.get(self.tensor_keys.value)
except KeyError:
old_state_value = tensordict.get(
self.tensor_keys.value, None
) # TODO: None soon to be removed
if old_state_value is None:
raise KeyError(
f"clip_value is set to {self.clip_value}, but "
f"the key {self.tensor_keys.value} was not found in the input tensordict. "
Expand All @@ -508,9 +510,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
) if self.functional else contextlib.nullcontext():
state_value_td = self.critic_network(tensordict)

try:
state_value = state_value_td.get(self.tensor_keys.value)
except KeyError:
state_value = state_value_td.get(
self.tensor_keys.value, None
) # TODO: None soon to be removed
if state_value is None:
raise KeyError(
f"the key {self.tensor_keys.value} was not found in the critic output tensordict. "
f"Make sure that the value_key passed to PPO is accurate."
Expand Down
43 changes: 23 additions & 20 deletions torchrl/objectives/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,32 +398,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

if self.clip_value:
try:
old_state_value = tensordict.get(self.tensor_keys.value).clone()
except KeyError:
old_state_value = tensordict.get(
self.tensor_keys.value, None
) # TODO: None soon to be removed
if old_state_value is None:
raise KeyError(
f"clip_value is set to {self.clip_value}, but "
f"the key {self.tensor_keys.value} was not found in the input tensordict. "
f"Make sure that the value_key passed to Reinforce exists in the input tensordict."
)
old_state_value = old_state_value.clone()

try:
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(
*self.critic_network.in_keys, strict=False
)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():
state_value = self.critic_network(tensordict_select).get(
self.tensor_keys.value
)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
except KeyError:
target_return = tensordict.get(
self.tensor_keys.value_target, None
) # TODO: None soon to be removed
if target_return is None:
raise KeyError(
f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
f"Make sure you provided the right key and the value_target (i.e. the target "
Expand All @@ -432,6 +421,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
f"can be used for the value loss."
)

tensordict_select = tensordict.select(
*self.critic_network.in_keys, strict=False
)
with self.critic_network_params.to_module(
self.critic_network
) if self.functional else contextlib.nullcontext():
state_value = self.critic_network(tensordict_select).get(
self.tensor_keys.value
)
loss_value = distance_loss(
target_return,
state_value,
loss_function=self.loss_critic_type,
)
clip_fraction = None
if self.clip_value:
loss_value, clip_fraction = _clip_value_loss(
Expand Down
Loading