Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged
merged 19 commits into from
Sep 4, 2024

Conversation

albertbou92
Copy link
Contributor

@albertbou92 albertbou92 commented Aug 12, 2024

Description

At the moment objective classes do not allow to use an actor with a composite distribution.

This PR aims to fix this. I have started with PPO, it turned out to required more changes than I anticipated. In particular, I am struggling with the test test_ppo_notensordict.

Once these modification are correct, I will move on the the tests of the other on-policy objectives and then to all other objectives.

This PR requires the TensorDict PR pytorch/tensordict#961 to be merged.

Copy link

pytorch-bot bot commented Aug 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2391

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 6 Unrelated Failures

As of commit 69922fa with merge base a6310ae (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 12, 2024
@albertbou92 albertbou92 marked this pull request as draft August 12, 2024 18:16
@vmoens vmoens changed the title [BUG] Allow for composite action distributions in losses [BugFix] Allow for composite action distributions in losses Aug 13, 2024
@vmoens vmoens added bug Something isn't working enhancement New feature or request labels Aug 13, 2024
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work!
There's an exception with test_ppo_notensordict in test_cost.py

Have a look at the couple of comments I left

torchrl/objectives/a2c.py Outdated Show resolved Hide resolved
torchrl/objectives/a2c.py Outdated Show resolved Hide resolved
torchrl/objectives/ppo.py Outdated Show resolved Hide resolved
torchrl/objectives/ppo.py Outdated Show resolved Hide resolved
torchrl/objectives/ppo.py Outdated Show resolved Hide resolved
torchrl/objectives/ppo.py Outdated Show resolved Hide resolved
@albertbou92
Copy link
Contributor Author

Thanks a lot for the feedback!

I have made a few changes:

  1. I have removed any specific distribution id check isinstance(dist, CompositeDistribution), which should make the code more general. I am instead simply checking whether the log_prob is a TensorDict as suggested, or alternatively if action is a torch.Tensor. I think using a TransformedDistribution to do that is a bit overkill as it requires coding some methods that we don't need, but if you prefer it I can further look into it.
  2. Also added tests for A2C. At the moment, a composite action distribution is simply not compatible with notensordict. The action becomes a nested structure which makes it difficult. Any suggestion how to solve that?

@albertbou92 albertbou92 requested a review from vmoens August 14, 2024 11:39
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safe to assume that people won't use this feature with non-tensordict inputs, because the action will be a tensordict anyway.
I would just document it properly in the loss docstrings where we explain how to use the loss without tensordict.

@@ -383,26 +383,39 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x).mean(0)
log_prob = dist.log_prob(x)
if isinstance(log_prob, TensorDict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lazy stack is not a TensorDict but a TensorDict base.
Also ideally we would want this to work with tensorclasses.
The way to go should be to use is_tensor_collection from tensordict lib.

@@ -449,28 +449,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x).mean(0)
log_prob = dist.log_prob(x)
if isinstance(log_prob, TensorDict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0)
previous_log_prob = previous_dist.log_prob(x)
current_log_prob = current_dist.log_prob(x)
if isinstance(x, TensorDict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safe to assume that people won't use this feature with non-tensordict inputs, because the action will be a tensordict anyway.
I would just document it properly in the loss docstrings where we explain how to use the loss without tensordict.

@albertbou92 albertbou92 marked this pull request as ready for review August 14, 2024 16:07
@albertbou92
Copy link
Contributor Author

Done! I will do the off-policy and offline losses in separate PRs.

@albertbou92 albertbou92 requested a review from vmoens August 14, 2024 16:11
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!
Do we need a simple test for this?
Like a dedicated function in PPOTest that runs it with composite dists?

@@ -383,26 +395,39 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x).mean(0)
log_prob = dist.log_prob(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, was this a bug or did we sum the log-probs automatically?

Copy link
Contributor Author

@albertbou92 albertbou92 Aug 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is simply because the log_prob() method for a composite dist will return a TD instead of a Tensor, so we compute the entropy in 2 steps.

This is the old version:

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            entropy = -dist.log_prob(x).mean(0)
        return entropy.unsqueeze(-1)

This is the new version. It simply retrieves the log tensor before computing the entropy.

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            log_prob = dist.log_prob(x)
            if is_tensor_collection(log_prob):
                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
            entropy = -log_prob.mean(0)
        return entropy.unsqueeze(-1)

torchrl/objectives/a2c.py Outdated Show resolved Hide resolved
torchrl/objectives/ppo.py Outdated Show resolved Hide resolved
@albertbou92
Copy link
Contributor Author

Regarding a dedicated test, how do you usually approach this decision?

When I thought about testing different dists I saw it a bit like testing for different types of ValueEstimators. It should work both with single dists and with composite dists in all the tested situations. So I added it to all tests (except the notensordict tests).

We could probably switch to a single dedicated test function though.

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks for this!

@vmoens vmoens changed the title [BugFix] Allow for composite action distributions in losses [BugFix] Allow for composite action distributions in PPO/A2C losses Aug 28, 2024
@albertbou92
Copy link
Contributor Author

Computing the entropy for composite distributions is not fully resolved, particularly when dealing with a composite distribution that includes some distributions with an implemented entropy() method and others without.

We could add an entropy method to CompositeDistribution as suggested in this PR: pytorch/tensordict#981

wdyt?

@vmoens
Copy link
Contributor

vmoens commented Sep 3, 2024

From discord:

Hey, sorry, was out on Holidays last week. I can have a look at the PR. What I was wondering: Why is the behavior for log_prob different between composite and non-composite distributions? I.e. why not have log_prob always just return the (combined) log-prob and have a separate log_prob_composite function for the specific case of returning individual log-probs in the composite distribution?

I think it's a good idea, I prefer a log_prob which always returns a single tensor.

@vmoens
Copy link
Contributor

vmoens commented Sep 3, 2024

To add on my previous comment, here is how I would address this:

  • Add a warning saying that from v0.7 log-prob will return a tensor. Users can achieve this already by constructing the lib with an additional temporary kwarg return_log_prob_tensor=True (this will be the default in v0.7).
  • Add a method log_prob_composite for those who need it.

@albertbou92
Copy link
Contributor Author

From discord:

Hey, sorry, was out on Holidays last week. I can have a look at the PR. What I was wondering: Why is the behavior for log_prob different between composite and non-composite distributions? I.e. why not have log_prob always just return the (combined) log-prob and have a separate log_prob_composite function for the specific case of returning individual log-probs in the composite distribution?

I think it's a good idea, I prefer a log_prob which always returns a single tensor.

makes sense

@vmoens
Copy link
Contributor

vmoens commented Sep 4, 2024

Merging this to clear space in the PR list but we should take care of #2391 (comment) sooner than later! Wanna give a shot at it or should I?

@vmoens vmoens merged commit 49d7f74 into pytorch:main Sep 4, 2024
67 of 76 checks passed
@vmoens vmoens deleted the loss_composite_dist branch September 4, 2024 14:13
@albertbou92
Copy link
Contributor Author

Merging this to clear space in the PR list but we should take care of #2391 (comment) sooner than later! Wanna give a shot at it or should I?

I will give it a shot, give me a few days

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants