
[BugFix] Discrete SAC rewrite #1461

Merged
merged 18 commits on Aug 30, 2023
Conversation

matteobettini
Contributor

@matteobettini matteobettini commented Aug 14, 2023

Fixes #1459.

This PR rewrites discrete SAC, enabling the following features:

  • support for more discrete action spaces
  • support for data with arbitrary batch sizes
  • following the modular implementation of continuous SAC, with _compute_target, _actor_loss, and _value_loss subfunctions
  • improved readability
  • grouping utils that work with discrete spaces (in DQN, discrete SAC, actors, ...) into torchrl.data.utils

@vmoens @smorad @hyerra

Signed-off-by: Matteo Bettini <matbet@meta.com>
@facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) Aug 14, 2023
Signed-off-by: Matteo Bettini <matbet@meta.com>
@smorad
Contributor

smorad commented Aug 14, 2023

It would also be very cool to utilize the EnsembleModule here instead of the previous approach of stacking network params. This will ensure the ensemble members are initialized correctly and give a bit more power to the user.

@matteobettini
Contributor Author

matteobettini commented Aug 14, 2023

I am not stacking params, I am using vmap like in the normal SAC, so I guess it behaves exactly like the EnsembleModule?
Let me see if I can swap it in.
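For context, the vmap-over-stacked-params approach referred to here can be sketched with plain torch.func; the names and sizes below are illustrative, not the actual torchrl internals:

```python
import copy
import torch
from torch import nn
from torch.func import stack_module_state, functional_call

# Illustrative sketch of vmapping a Q-network ensemble over stacked
# parameters, in the spirit of the continuous SAC loss (not torchrl code).
models = [nn.Linear(4, 2) for _ in range(3)]   # 3 toy Q-networks
params, buffers = stack_module_state(models)   # stack each tensor along a new dim 0

base = copy.deepcopy(models[0]).to("meta")     # stateless template module

def fmodel(p, b, x):
    # Run the template with one ensemble member's params/buffers
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)                          # shared batch of inputs
q_values = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(q_values.shape)                          # torch.Size([3, 8, 2])
```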

Signed-off-by: Matteo Bettini <matbet@meta.com>
@matteobettini
Contributor Author

@smorad At a quick look, it seems that the losses have their own way of making ensembles in https://github.com/pytorch/rl/blob/main/torchrl/objectives/common.py#L174 (which also creates target params).
So I think that if we want to use EnsembleModule, the best place would be to integrate it there in the future.
For now I am just going to stick to plain vmap, as it is used in the other losses (like SAC).

Signed-off-by: Matteo Bettini <matbet@meta.com>
@hyerra
Contributor

hyerra commented Aug 14, 2023

Hey @matteobettini, thanks for fixing this for me! It seems to work on my side as well.

@matteobettini matteobettini marked this pull request as ready for review August 14, 2023 16:05
@smorad
Contributor

smorad commented Aug 15, 2023

@smorad at a quick look it seems that the losses have their way of making ensembles in https://github.com/pytorch/rl/blob/main/torchrl/objectives/common.py#L174 (which also creates target params).

Yes they do, as the losses were implemented before the addition of EnsembleModule. @vmoens and I briefly discussed this in this issue as a more generic form of what is currently done in the losses. It is also vmapped so it should be very fast.

Adding it would be fairly straightforward and simplify the loss function:

loss_fn = SACLoss(
  q_network=EnsembleModule(MyQNetwork(), num_copies=5),
  ...
)

class SACLoss:
  ...
  def forward(...):
    ...
    q_values = self.q_network(td)
    if isinstance(self.q_network, EnsembleModule):
      # Min-reduce over Q values
      q_values['action_value'] = q_values['action_value'].min(0).values
      # EnsembleModule will expand at dim 0, so this just squeezes it back
      q_values = q_values[0]
    ...

It would open up a lot of interesting options. For example, one could consider using an Ensemble of actors too, and taking the combined mode action in the discrete SAC case.
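The "combined mode action" idea can be illustrated with a tiny standalone snippet (the shapes and values here are made up):

```python
import torch

# Illustrative only: majority-vote ("mode") action across an actor ensemble.
# Rows are ensemble members, columns are batch samples.
ensemble_actions = torch.tensor([[1, 2, 2],
                                 [0, 2, 1],
                                 [1, 2, 1]])
# Reduce the ensemble dim: per-sample most frequent action
mode_action = ensemble_actions.mode(dim=0).values
print(mode_action)  # tensor([1, 2, 1])
```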

@matteobettini
Contributor Author

matteobettini commented Aug 15, 2023

Yeah, ideally it would be cool to pass ensembles in place of each nn module passed to the losses.

I am not super convinced by the isinstance check you added in the loss.

Wouldn't there be a seamless way of passing ensembles in place of the current modules and managing the dimensionality outside of the loss?

Something like what you said here:

policy = TensorDictSequential(
  Ensemble(in_keys=['observation'], out_keys=['ensemble_state_action_value']),
  Reduce(in_keys=['ensemble_state_action_value'], out_keys=['state_action_value'], reduce_fn=lambda x, dim: x.min(dim=dim)),
  QValueActor(env.action_spec)
)

I feel like losses should explicitly consider ensembles only where they are required by default (as in SAC).
If you want to use them for other losses (like combining actor ensembles), you can do it in your own module, like in the snippet above.

@smorad
Contributor

smorad commented Aug 15, 2023

Actually, my above code is indeed wrong; we don't want to reduce at all in the Q-function loss! Usually, we want to update all ensemble members at once.

You could instead just do

loss_fn = SACLoss(
  q_network=EnsembleModule(MyQFn(), num_copies=5),
  q_reduce=MinReduce(reduce_key='action_value'),
  ...
)

class SACLoss:
  def forward_q_loss(...):
    ...
    # Of shape [ensemble_size, batch, time, agent, ...]
    q_values = self.q_network(td)
    if self.q_reduce is not None:
      # Unsqueeze so the target is [1, batch, time, agent, ...]
      target = td_estimator.estimate(td.unsqueeze(0))
    error = (q_values - target).pow(2).mean()
    return error

That code using Reduce would be used when training the actor and in the agent passed to the collector. There, we just want the "best" Q value estimate instead of all of them.

def forward_actor_loss(...):
  # TD is originally of shape [ensemble, batch, ...] from the compute_q_loss
  # Now let's reduce it to [batch, ...]
  if self.q_reduce is not None:
    td = self.q_reduce(td)
  # Continue to compute actor loss as normal
  ...

You could easily do min/softmin/mean/median/whatever reduce you want. For the case where you don't want an ensemble, just do this:

loss_fn = SACLoss(
  q_network=MyQFn(),
  q_reduce=None,
  ...
)
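A pessimistic min-reduce over the ensemble dimension (what a reducer like the hypothetical MinReduce above would do) is a one-liner in plain torch:

```python
import torch

# Illustrative: reduce an ensemble of Q-estimates [ensemble, batch, actions]
# to a single pessimistic estimate [batch, actions] by taking the min.
q = torch.randn(5, 32, 4)       # 5 ensemble members
q_min = q.min(dim=0).values     # shape [32, 4], lower bound over members
print(q_min.shape)
```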

if action.shape != action_value.shape:
    # unsqueeze the action if it lacks a trailing singleton dim
    action = action.unsqueeze(-1)
chosen_action_value = torch.gather(action_value, -1, index=action).squeeze(-1)
Contributor

@hyerra hyerra Aug 16, 2023

@matteobettini Just noticed a bug here when I was testing the categorical case. In my setup, I have 2 Q-value nets, which causes my action_value to have a shape of torch.Size([2, 256, 1, 5]). However, my action is just of shape torch.Size([256, 1, 1]); note that it has the extra dimension at the end because of the unsqueeze above. This causes the gather to fail, as the index and input tensors have mismatched dimensions. The following code seemed to work for me:

action_idx = action.expand(torch.Size([action_value.shape[0], *action.shape]))
chosen_action_value = torch.gather(action_value, -1, index=action_idx).squeeze(-1)
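This fix can be sanity-checked with a standalone repro (sizes taken from the comment above):

```python
import torch

# Standalone repro of the shape mismatch and the proposed fix:
# 2 Q-nets, batch 256, 1 agent, 5 discrete actions.
action_value = torch.randn(2, 256, 1, 5)
action = torch.randint(0, 5, (256, 1, 1))   # categorical index, already unsqueezed

# Broadcast the index over the leading ensemble dim before gathering
action_idx = action.expand(torch.Size([action_value.shape[0], *action.shape]))
chosen_action_value = torch.gather(action_value, -1, index=action_idx).squeeze(-1)
print(chosen_action_value.shape)            # torch.Size([2, 256, 1])
```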

Contributor Author

Thanks so much for spotting this, phew.
Can you tell me if it now works in both the one-hot and categorical cases for you?

Contributor

Yep, it works for me!

Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Contributor

@vmoens vmoens left a comment

LGTM, some minor comments

"loss_actor": loss_actor.mean(),
"loss_qvalue": loss_value.mean(),
"loss_alpha": loss_alpha.mean(),
"alpha": self._alpha,
Contributor

shouldn't we detach?

Contributor Author

This is a property; if you look, it should already be detached (I followed what you wrote in the other SAC):


@property
def _alpha(self):
    if self.min_log_alpha is not None:
        self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
    with torch.no_grad():
        alpha = self.log_alpha.exp()
    return alpha

torchrl/objectives/sac.py (outdated, resolved)
torchrl/objectives/sac.py (outdated, resolved)
torchrl/objectives/sac.py (resolved)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Contributor Author

I added this example as a bonus

Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
@matteobettini matteobettini marked this pull request as draft August 26, 2023 21:08
Signed-off-by: Matteo Bettini <matbet@meta.com>
Signed-off-by: Matteo Bettini <matbet@meta.com>
@matteobettini matteobettini marked this pull request as ready for review August 27, 2023 09:32
@matteobettini
Contributor Author

I have run the discrete SAC example available on main to compare main vs this PR.

The reported metrics seem the same:
[Screenshot 2023-08-27 at 10 36 11: training curve comparison]

@vmoens added the bug (Something isn't working) and Refactoring (Refactoring of an existing feature) labels Aug 30, 2023
@vmoens vmoens merged commit 78b2bb2 into pytorch:main Aug 30, 2023
@matteobettini matteobettini deleted the fix-sac branch August 30, 2023 06:59
@vmoens
Contributor

vmoens commented Aug 30, 2023

It would also be very cool to utilize the EnsembleModule here instead of the previous approach of stacking network params. This will ensure the ensemble members are initialized correctly and give a bit more power to the user.

@smorad happy to consider this in the future. We should open an issue about it

vmoens pushed a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Signed-off-by: Matteo Bettini <matbet@meta.com>
Labels
bug: Something isn't working
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Refactoring: Refactoring of an existing feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Using DiscreteSAC in Multi-Agent environments
5 participants