Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged
merged 19 commits into from
Sep 4, 2024
Prev Previous commit
Next Next commit
docstrings and minor fixes
  • Loading branch information
albertbou92 committed Aug 14, 2024
commit 9ce3c5876b977f1dc004979b929e2ad908a73098
19 changes: 16 additions & 3 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -9238,13 +9238,26 @@ def test_a2c_tensordict_keys_run(
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
@pytest.mark.parametrize("composite_action_dist", [False, ])
@pytest.mark.parametrize(
"composite_action_dist",
[
False,
],
)
def test_a2c_notensordict(
self, action_key, observation_key, reward_key, done_key, terminated_key, composite_action_dist
self,
action_key,
observation_key,
reward_key,
done_key,
terminated_key,
composite_action_dist,
):
torch.manual_seed(self.seed)

actor = self._create_mock_actor(observation_key=observation_key, composite_action_dist=composite_action_dist)
actor = self._create_mock_actor(
observation_key=observation_key, composite_action_dist=composite_action_dist
)
value = self._create_mock_value(observation_key=observation_key)
td = self._create_seq_mock_data_a2c(
action_key=action_key,
Expand Down
16 changes: 14 additions & 2 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict import (
is_tensor_collection,
TensorDict,
TensorDictBase,
TensorDictParams,
)
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
Expand Down Expand Up @@ -190,6 +195,13 @@ class A2CLoss(LossModule):
... next_reward = torch.randn(*batch, 1),
... next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()

.. note::
There is an exception regarding compatibility with non-tensordict-based modules.
If the actor network is probabilistic and uses a `~tensordict.nn.distributions.CompositeDistribution`,
this class must be used with tensordicts and cannot function as a tensordict-independent module.
This is because composite action spaces inherently rely on the structured representation of data provided by
tensordicts to handle their actions.
"""

@dataclass
Expand Down Expand Up @@ -384,7 +396,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, was this a bug or did we sum the log-probs automatically?

Copy link
Contributor Author

@albertbou92 albertbou92 Aug 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is simply because the log_prob() method for a composite dist will return a TD instead of a Tensor, so we compute the entropy in 2 steps.

This is the old version:

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            entropy = -dist.log_prob(x).mean(0)
        return entropy.unsqueeze(-1)

This is the new version. It simply retrieves the log tensor before computing the entropy.

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            log_prob = dist.log_prob(x)
            if is_tensor_collection(log_prob):
                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
            entropy = -log_prob.mean(0)
        return entropy.unsqueeze(-1)

if isinstance(log_prob, TensorDict):
if is_tensor_collection(log_prob):
log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
Expand Down
17 changes: 14 additions & 3 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict import (
is_tensor_collection,
TensorDict,
TensorDictBase,
TensorDictParams,
)
from tensordict.nn import (
dispatch,
ProbabilisticTensorDictModule,
Expand Down Expand Up @@ -238,6 +243,12 @@ class PPOLoss(LossModule):
... next_observation=torch.randn(*batch, n_obs))
>>> loss_objective.backward()

.. note::
There is an exception regarding compatibility with non-tensordict-based modules.
If the actor network is probabilistic and uses a `~tensordict.nn.distributions.CompositeDistribution`,
this class must be used with tensordicts and cannot function as a tensordict-independent module.
This is because composite action spaces inherently rely on the structured representation of data provided by
tensordicts to handle their actions.
"""

@dataclass
Expand Down Expand Up @@ -450,7 +461,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
if isinstance(log_prob, TensorDict):
if is_tensor_collection(log_prob):
log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
Expand Down Expand Up @@ -1119,7 +1130,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
x = previous_dist.sample((self.samples_mc_kl,))
previous_log_prob = previous_dist.log_prob(x)
current_log_prob = current_dist.log_prob(x)
if isinstance(x, TensorDict):
if is_tensor_collection(x):
previous_log_prob = previous_log_prob.get(
self.tensor_keys.sample_log_prob
)
Expand Down