diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index fb8fbff2ccf..a6cb21dd2a4 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -514,32 +514,21 @@ def out_keys(self, values):

     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        shape = None
-        if tensordict.ndimension() > 1:
-            shape = tensordict.shape
-            tensordict_reshape = tensordict.reshape(-1)
-        else:
-            tensordict_reshape = tensordict
-
-        q_loss, metadata = self.q_loss(tensordict_reshape)
-        cql_loss, cql_metadata = self.cql_loss(tensordict_reshape)
+        q_loss, metadata = self.q_loss(tensordict)
+        cql_loss, cql_metadata = self.cql_loss(tensordict)
         if self.with_lagrange:
-            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(
-                tensordict_reshape
-            )
+            alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
             metadata.update(alpha_prime_metadata)
-        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape)
-        loss_actor, actor_metadata = self.actor_loss(tensordict_reshape)
+        loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
+        loss_actor, actor_metadata = self.actor_loss(tensordict)
         loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
         metadata.update(bc_metadata)
         metadata.update(cql_metadata)
         metadata.update(actor_metadata)
         metadata.update(alpha_metadata)
-        tensordict_reshape.set(
+        tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
-        if shape:
-            tensordict.update(tensordict_reshape.view(shape))
         out = {
             "loss_actor": loss_actor,
             "loss_actor_bc": loss_actor_bc,
@@ -682,7 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
             )
             # take max over actions
             state_action_value = state_action_value.reshape(
-                self.num_qvalue_nets, tensordict.shape[0], self.num_random, -1
+                torch.Size(
+                    [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
+                )
             ).max(-2)[0]
             # take min over qvalue nets
             next_state_value = state_action_value.min(0)[0]
@@ -739,14 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
                 "This could be caused by calling cql_loss method before q_loss method."
             )

-        random_actions_tensor = (
-            torch.FloatTensor(
-                tensordict.shape[0] * self.num_random,
+        random_actions_tensor = pred_q1.new_empty(
+            (
+                *tensordict.shape[:-1],
+                tensordict.shape[-1] * self.num_random,
                 tensordict[self.tensor_keys.action].shape[-1],
             )
-            .uniform_(-1, 1)
-            .to(tensordict.device)
-        )
+        ).uniform_(-1, 1)
         curr_actions_td, curr_log_pis = self._get_policy_actions(
             tensordict.copy(),
             self.actor_network_params,
@@ -833,7 +823,7 @@ def filter_and_repeat(name, x):
                 q_new[0] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )
         cat_q2 = torch.cat(
             [
@@ -841,23 +831,23 @@ def filter_and_repeat(name, x):
                 q_new[1] - new_log_pis.detach().unsqueeze(-1),
                 q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
             ],
-            1,
+            -1,
         )

         min_qf1_loss = (
-            torch.logsumexp(cat_q1 / self.temperature, dim=1)
+            torch.logsumexp(cat_q1 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )
         min_qf2_loss = (
-            torch.logsumexp(cat_q2 / self.temperature, dim=1)
+            torch.logsumexp(cat_q2 / self.temperature, dim=-1)
             * self.min_q_weight
             * self.temperature
         )

         # Subtract the log likelihood of data
-        cql_q1_loss = min_qf1_loss - pred_q1 * self.min_q_weight
-        cql_q2_loss = min_qf2_loss - pred_q2 * self.min_q_weight
+        cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
+        cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight

         # write cql losses in tensordict for alpha prime loss
         tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
@@ -1080,9 +1070,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(value_network, "action_space"):
                 action_space = value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = value_network.action_space
@@ -1205,8 +1195,6 @@ def value_loss(
         with torch.no_grad():
             td_error = (pred_val_index - target_value).pow(2)
             td_error = td_error.unsqueeze(-1)
-            if tensordict.device is not None:
-                td_error = td_error.to(tensordict.device)

         tensordict.set(
             self.tensor_keys.priority,
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index d86442fca12..cfa5a332df9 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -495,23 +495,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         To see what keys are expected in the input tensordict and what keys are
         expected as output, check the class's `"in_keys"` and `"out_keys"` attributes.
""" - shape = None - if tensordict.ndimension() > 1: - shape = tensordict.shape - tensordict_reshape = tensordict.reshape(-1) - else: - tensordict_reshape = tensordict - - loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape) - loss_actor, metadata_actor = self.actor_loss(tensordict_reshape) + loss_qvalue, value_metadata = self.qvalue_loss(tensordict) + loss_actor, metadata_actor = self.actor_loss(tensordict) loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"]) - tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) + tensordict.set(self.tensor_keys.priority, value_metadata["td_error"]) if loss_actor.shape != loss_qvalue.shape: raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}" ) - if shape: - tensordict.update(tensordict_reshape.view(shape)) entropy = -metadata_actor["log_prob"] out = { "loss_actor": loss_actor, diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index c4639b70bdd..8e30019955b 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -373,16 +373,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - shape = None - if tensordict.ndimension() > 1: - shape = tensordict.shape - tensordict_reshape = tensordict.reshape(-1) - else: - tensordict_reshape = tensordict - - loss_actor, metadata = self.actor_loss(tensordict_reshape) - loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape) - loss_value, metadata_value = self.value_loss(tensordict_reshape) + loss_actor, metadata = self.actor_loss(tensordict) + loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict) + loss_value, metadata_value = self.value_loss(tensordict) metadata.update(metadata_qvalue) metadata.update(metadata_value) @@ -392,13 +385,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}" ) - tensordict_reshape.set( + tensordict.set( self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values ) - if shape: - tensordict.update(tensordict_reshape.view(shape)) - - entropy = -tensordict_reshape.get(self.tensor_keys.log_prob).detach() + entropy = -tensordict.get(self.tensor_keys.log_prob).detach() out = { "loss_actor": loss_actor, "loss_qvalue": loss_qvalue, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 6350538db16..f37b0fba6f0 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -577,30 +577,21 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - shape = None - if tensordict.ndimension() > 1: - shape = tensordict.shape - tensordict_reshape = tensordict.reshape(-1) - else: - tensordict_reshape = tensordict - if self._version == 1: - loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape) - loss_value, _ = self._value_loss(tensordict_reshape) + loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict) + loss_value, _ = self._value_loss(tensordict) else: - loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape) + loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict) loss_value = None - loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) + loss_actor, metadata_actor = self._actor_loss(tensordict) loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) - tensordict_reshape.set(self.tensor_keys.priority, 
value_metadata["td_error"]) + tensordict.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}" ) - if shape: - tensordict.update(tensordict_reshape.view(shape)) entropy = -metadata_actor["log_prob"] out = { "loss_actor": loss_actor, @@ -1158,26 +1149,17 @@ def in_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - shape = None - if tensordict.ndimension() > 1: - shape = tensordict.shape - tensordict_reshape = tensordict.reshape(-1) - else: - tensordict_reshape = tensordict - - loss_value, metadata_value = self._value_loss(tensordict_reshape) - loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) + loss_value, metadata_value = self._value_loss(tensordict) + loss_actor, metadata_actor = self._actor_loss(tensordict) loss_alpha = self._alpha_loss( log_prob=metadata_actor["log_prob"], ) - tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"]) + tensordict.set(self.tensor_keys.priority, metadata_value["td_error"]) if loss_actor.shape != loss_value.shape: raise RuntimeError( f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}" ) - if shape: - tensordict.update(tensordict_reshape.view(shape)) entropy = -metadata_actor["log_prob"] out = { "loss_actor": loss_actor,