diff --git a/test/test_cost.py b/test/test_cost.py
index ab95c55ef83..b11cec924e3 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -7565,6 +7565,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8627,6 +8630,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index ff9b5f3883e..34c62bc3260 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -420,8 +420,13 @@ def _log_probs(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # With a composite distribution and aggregate_probabilities toggled
+                # off, log_prob may return a tensordict rather than a tensor.
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob
         log_prob = log_prob.unsqueeze(-1)
         return log_prob, dist
 
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index b4779a90663..9d9790ab294 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -490,8 +490,13 @@ def _log_weight(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # With a composite distribution and aggregate_probabilities toggled
+                # off, log_prob may return a tensordict rather than a tensor.
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob
 
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -1130,7 +1135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
                 x = previous_dist.sample((self.samples_mc_kl,))
                 previous_log_prob = previous_dist.log_prob(x)
                 current_log_prob = current_dist.log_prob(x)
-                if is_tensor_collection(x):
+                if is_tensor_collection(current_log_prob):
                     previous_log_prob = previous_log_prob.get(
                         self.tensor_keys.sample_log_prob
                     )
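
For context, a minimal sketch of the behavior these patches guard against. It assumes tensordict's CompositeDistribution as used at the time of this change; the aggregate_probabilities keyword and the default "sample_log_prob" key are taken from the diff above and may differ in other tensordict versions.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution

    params = TensorDict(
        {"action1": {"loc": torch.zeros(3), "scale": torch.ones(3)}},
        batch_size=[],
    )

    # With aggregate_probabilities=True, log_prob returns a plain tensor.
    dist = CompositeDistribution(
        params,
        distribution_map={"action1": torch.distributions.Normal},
        aggregate_probabilities=True,
    )
    sample = dist.sample()
    assert isinstance(dist.log_prob(sample), torch.Tensor)

    # With it toggled off, log_prob instead returns a tensordict carrying
    # per-leaf log-probs; the aggregate lives under the log-prob key, which
    # is what the patched losses now fetch.
    dist = CompositeDistribution(
        params,
        distribution_map={"action1": torch.distributions.Normal},
        aggregate_probabilities=False,
    )
    sample = dist.sample()
    maybe_log_prob = dist.log_prob(sample)
    if not isinstance(maybe_log_prob, torch.Tensor):
        log_prob = maybe_log_prob.get("sample_log_prob")
    else:
        log_prob = maybe_log_prob

This is why the loss modules branch on the returned type instead of assuming a tensor, and why the tests pin aggregate_probabilities explicitly.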