diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 8d594fc91a3..582ac88f52d 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1085,12 +1085,10 @@ def _get_stop_and_length(self, storage, fallback=True):
             # We first try with the traj_key
             try:
                 if isinstance(storage, TensorStorage):
-                    trajectory = storage[:].get(self._used_traj_key)
+                    trajectory = storage[:][self._used_traj_key]
                 else:
                     try:
-                        trajectory = storage[:].get(self.traj_key)
-                    except KeyError:
-                        raise
+                        trajectory = storage[:][self.traj_key]
                     except Exception:
                         raise RuntimeError(
                             "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
@@ -1112,9 +1110,7 @@ def _get_stop_and_length(self, storage, fallback=True):
         else:
             try:
                 try:
-                    done = storage[:].get(self.end_key)
-                except KeyError:
-                    raise
+                    done = storage[:][self.end_key]
                 except Exception:
                     raise RuntimeError(
                         "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py
index 99f38ebb32c..35f122b770a 100644
--- a/torchrl/envs/transforms/gym_transforms.py
+++ b/torchrl/envs/transforms/gym_transforms.py
@@ -150,9 +150,8 @@ def _step(self, tensordict, next_tensordict):
         end_of_life = torch.as_tensor(
             tensordict.get(self.lives_key) > lives, device=self.parent.device
         )
-        try:
-            done = next_tensordict.get(self.done_key)
-        except KeyError:
+        done = next_tensordict.get(self.done_key, None)  # TODO: None soon to be removed
+        if done is None:
             raise KeyError(
                 f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {tensordict.keys(True, True)}. "
                 f"Make sure to pass the appropriate done_key to the {type(self)} transform."
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 4a0948e1bca..a236b80d56c 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -408,34 +408,23 @@ def _log_probs(
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value).clone()
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                     f"Make sure that the value_key passed to A2C exists in the input tensordict."
                 )
-
-        try:
-            # TODO: if the advantage is gathered by forward, this introduces an
-            # overhead that we could easily reduce.
-            target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(
-                *self.critic_network.in_keys, strict=False
-            )
-            with self.critic_network_params.to_module(
-                self.critic_network
-            ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic_network(
-                    tensordict_select,
-                ).get(self.tensor_keys.value)
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
-        except KeyError:
+            old_state_value = old_state_value.clone()
+
+        # TODO: if the advantage is gathered by forward, this introduces an
+        # overhead that we could easily reduce.
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
@@ -443,6 +432,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
                 f"can be used for the value loss."
             )
+        tensordict_select = tensordict.select(
+            *self.critic_network.in_keys, strict=False
+        )
+        with self.critic_network_params.to_module(
+            self.critic_network
+        ) if self.functional else contextlib.nullcontext():
+            state_value = self.critic_network(
+                tensordict_select,
+            ).get(self.tensor_keys.value)
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
         clip_fraction = None
         if self.clip_value:
             loss_value, clip_fraction = _clip_value_loss(
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index f843a3659b5..c9dc281ef41 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -355,7 +355,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         td_copy.set(self.tensor_keys.local_value, pred_val_index)  # [*B, n_agents, 1]
         with self.mixer_network_params.to_module(self.mixer_network):
             self.mixer_network(td_copy)
-        pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1)
+        pred_val_index = td_copy[self.tensor_keys.global_value].squeeze(-1)
         # [*B] this is global and shared among the agents as will be the target
 
         target_value = self.value_estimator.value_estimate(
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 20dfab3d6b4..c29bc73dfa8 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -482,9 +482,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         # overhead that we could easily reduce.
         if self.separate_losses:
             tensordict = tensordict.detach()
-        try:
-            target_return = tensordict.get(self.tensor_keys.value_target)
-        except KeyError:
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
@@ -494,9 +495,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             )
 
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value)
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
@@ -508,9 +510,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         ) if self.functional else contextlib.nullcontext():
             state_value_td = self.critic_network(tensordict)
 
-        try:
-            state_value = state_value_td.get(self.tensor_keys.value)
-        except KeyError:
+        state_value = state_value_td.get(
+            self.tensor_keys.value, None
+        )  # TODO: None soon to be removed
+        if state_value is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value} was not found in the critic output tensordict. "
                 f"Make sure that the value_key passed to PPO is accurate."
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 3d867b8cb99..af9f7d99b46 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -398,32 +398,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         if self.clip_value:
-            try:
-                old_state_value = tensordict.get(self.tensor_keys.value).clone()
-            except KeyError:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
                     f"Make sure that the value_key passed to Reinforce exists in the input tensordict."
                 )
+            old_state_value = old_state_value.clone()
 
-        try:
-            target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(
-                *self.critic_network.in_keys, strict=False
-            )
-            with self.critic_network_params.to_module(
-                self.critic_network
-            ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic_network(tensordict_select).get(
-                    self.tensor_keys.value
-                )
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
-        except KeyError:
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
                 f"Make sure you provided the right key and the value_target (i.e. the target "
@@ -432,6 +421,20 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"can be used for the value loss."
             )
+        tensordict_select = tensordict.select(
+            *self.critic_network.in_keys, strict=False
+        )
+        with self.critic_network_params.to_module(
+            self.critic_network
+        ) if self.functional else contextlib.nullcontext():
+            state_value = self.critic_network(tensordict_select).get(
+                self.tensor_keys.value
+            )
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
 
         clip_fraction = None
         if self.clip_value:
             loss_value, clip_fraction = _clip_value_loss(