[BugFix] Discrete SAC rewrite #1461
Conversation
It would also be very cool to utilize the `EnsembleModule` here.
I am not stacking params, I am using vmap like in the normal SAC, so I guess it behaves exactly like the ensemble module?
@smorad at a quick look, it seems that the losses have their own way of making ensembles in https://github.com/pytorch/rl/blob/main/torchrl/objectives/common.py#L174 (which also creates target params).
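For illustration, a minimal standalone sketch of how vmap over stacked parameters behaves like an explicit ensemble (this uses plain `torch.func`; the module and shapes below are made up and are not torchrl internals):

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

# Two independent Q-networks whose parameters get stacked along a new leading dim.
q_nets = [torch.nn.Linear(4, 3) for _ in range(2)]
params, buffers = stack_module_state(q_nets)

def run(p, b, obs):
    # Evaluate one ensemble member with the given parameter slice.
    return functional_call(q_nets[0], (p, b), (obs,))

obs = torch.randn(8, 4)
# vmap maps over the stacked params, producing a leading ensemble dim of size 2,
# just like an explicit ensemble module would.
q_values = vmap(run, in_dims=(0, 0, None))(params, buffers, obs)
print(q_values.shape)  # torch.Size([2, 8, 3])
```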
Hey @matteobettini, thanks for fixing this for me! It seems to work on my side as well.
Yes they do, as the losses were implemented before the addition of `EnsembleModule`. Adding it would be fairly straightforward and simplify the loss function:

```python
loss_fn = SACLoss(
    q_network=EnsembleModule(MyQNetwork(), num_copies=5),
    ...
)

class SACLoss:
    ...
    def forward(...):
        ...
        q_values = self.q_network(td)
        if isinstance(self.q_network, EnsembleModule):
            # Min reduce over q values
            q_values['action_value'] = q_values['action_value'].min(0).values
            # EnsembleModule will expand at dim 0, so this just squeezes it back
            q_values = q_values[0]
        ...
```

It would open up a lot of interesting options. For example, one could consider using an ensemble of actors too, and taking the combined mode action in the discrete SAC case (see the sketch below).
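On the "combined mode action" idea, a hypothetical way to aggregate an ensemble of discrete actors (the shapes and the two aggregation choices below are assumptions, not an existing torchrl API):

```python
import torch

# Logits from a hypothetical ensemble of 3 discrete actors: [ensemble, batch, n_actions].
ensemble_logits = torch.randn(3, 8, 5)

# Option 1: average the per-member action distributions, then take the mode.
mean_probs = ensemble_logits.softmax(-1).mean(0)   # [batch, n_actions]
combined_mode_action = mean_probs.argmax(-1)       # [batch]

# Option 2: majority vote over each member's greedy action.
votes = ensemble_logits.argmax(-1)                 # [ensemble, batch]
majority_action = votes.mode(0).values             # [batch]
```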
Yeah, ideally it would be cool to pass ensembles in place of each nn module passed to the losses. I am not super convinced by the `isinstance` check you used in the loss, though. Wouldn't there be a more seamless way of passing ensembles in place of the current modules and managing the dimensionality outside of the loss? Something like what you said here.
I feel like losses should explicitly consider ensembles only when they are required by default (as in SAC).
Actually, my above code is indeed wrong: we don't want to reduce at all in the Q function loss! Usually, we want to update all ensemble members at once. You could instead just do:

```python
loss_fn = SACLoss(
    q_network=EnsembleModule(MyQFn(), num_copies=5),
    q_reduce=MinReduce(reduce_key='action_value'),
    ...
)

class SACLoss:
    def forward_q_loss(...):
        ...
        # Of shape [ensemble_size, batch, time, agent, ...]
        q_values = self.q_network(td)
        # Unsqueeze so this is [1, batch, time, agent, ...]
        if self.q_reduce is not None:
            target = td_estimator.estimate(td.unsqueeze(0))
        error = (q_values - target).pow(2).mean()
        return error
```

The `q_reduce` would then only be applied where a single value is needed, e.g. in the actor loss:

```python
def forward_actor_loss(...):
    # TD is originally of shape [ensemble, batch, ...] from the compute_q_loss
    # Now let's reduce it to [batch, ...]
    if self.q_reduce is not None:
        td = self.q_reduce(td)
    # Continue to compute actor loss as normal
    ...
```

You could easily do min/softmin/mean/median/whatever reduce you want. For the case where you don't want an ensemble, just do this:

```python
loss_fn = SACLoss(
    q_network=MyQFn(),
    q_reduce=None,
    ...
)
```
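To make that proposal concrete, a rough sketch of what such a `MinReduce` could look like, assuming it receives a TensorDict with a leading ensemble dimension (`MinReduce` does not exist in torchrl; the class and keys below are hypothetical):

```python
import torch
from tensordict import TensorDict

class MinReduce:
    """Hypothetical reduce: min over the leading ensemble dim for one key, drop it for the rest."""

    def __init__(self, reduce_key: str = "action_value"):
        self.reduce_key = reduce_key

    def __call__(self, td: TensorDict) -> TensorDict:
        reduced = td.get(self.reduce_key).min(dim=0).values
        out = td[0].clone()  # index away the ensemble dim for every other entry
        out.set(self.reduce_key, reduced)
        return out

# Usage with a fake ensemble of 5 members over a batch of 32 and 4 actions.
td = TensorDict({"action_value": torch.randn(5, 32, 4)}, batch_size=[5, 32])
print(MinReduce()(td).get("action_value").shape)  # torch.Size([32, 4])
```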
```python
if action.shape != action_value.shape:
    # unsqueeze the action if it lacks a trailing singleton dim
    action = action.unsqueeze(-1)
chosen_action_value = torch.gather(action_value, -1, index=action).squeeze(-1)
```
@matteobettini Just noticed a bug here when I was testing the categorical case. In my setup, I have 2 q-value nets, which causes my `action_value` to have a shape of `torch.Size([2, 256, 1, 5])`. My `action`, however, is of shape `torch.Size([256, 1, 1])` (it has the extra dimension at the end because of the unsqueeze above). This causes the `gather` to fail, as the `index` and `input` tensors have mismatching dimensions. The following code seemed to work for me:
```python
action_idx = action.expand(torch.Size([action_value.shape[0], *action.shape]))
chosen_action_value = torch.gather(action_value, -1, index=action_idx).squeeze(-1)
```
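For reference, a standalone reproduction of the shape mismatch and the `expand` fix (the shapes are taken from the comment above; the tensors are dummies):

```python
import torch

# Two stacked Q-value nets, batch of 256, 1 agent, 5 discrete actions.
action_value = torch.randn(2, 256, 1, 5)
# Categorical action indices, already unsqueezed to a trailing singleton dim.
action = torch.randint(0, 5, (256, 1, 1))

# torch.gather(action_value, -1, action) fails: index has 3 dims, input has 4.
# Expanding the index with the leading ensemble dim makes the ranks match.
action_idx = action.expand(torch.Size([action_value.shape[0], *action.shape]))
chosen_action_value = torch.gather(action_value, -1, index=action_idx).squeeze(-1)
print(chosen_action_value.shape)  # torch.Size([2, 256, 1])
```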
Thanks so much for spotting this, phew.
Can you tell me if it now works in both the one-hot and categorical cases for you?
Yep, it works for me!
LGTM, some minor comments
"loss_actor": loss_actor.mean(), | ||
"loss_qvalue": loss_value.mean(), | ||
"loss_alpha": loss_alpha.mean(), | ||
"alpha": self._alpha, |
shouldn't we detach?
This is a property; if you look, it should already be detached (I followed what you wrote in the other SAC loss):
```python
@property
def _alpha(self):
    if self.min_log_alpha is not None:
        self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
    with torch.no_grad():
        alpha = self.log_alpha.exp()
    return alpha
```
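A tiny standalone check of that claim, assuming `log_alpha` is a learnable scalar parameter (sketch only, not the loss class itself):

```python
import torch

log_alpha = torch.nn.Parameter(torch.zeros(()))

# Computing the exp under no_grad keeps the logged value out of the autograd graph.
with torch.no_grad():
    alpha = log_alpha.exp()

print(alpha.requires_grad)  # False
```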
examples/multiagent/sac.py
I added this example as a bonus
@smorad happy to consider this in the future. We should open an issue about it.
Fixes #1459.
This PR rewrites discrete SAC, enabling the following features:

- modular `_compute_target`, `_actor_loss`, `_value_loss` subfunctions (a rough sketch of this structure follows below)

@vmoens @smorad @hyerra
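A hypothetical outline of how such modular subfunctions could compose (toy math and invented signatures for illustration only; see the diff for the actual implementation):

```python
import torch

class DiscreteSACSketch:
    """Toy illustration of the modular structure; not the real torchrl loss."""

    def __init__(self, alpha: float = 0.2, gamma: float = 0.99):
        self.alpha = alpha
        self.gamma = gamma

    def _compute_target(self, reward, done, next_probs, next_q):
        # V(s') = E_{a ~ pi}[Q(s', a) - alpha * log pi(a | s')], then bootstrap.
        with torch.no_grad():
            next_v = (next_probs * (next_q - self.alpha * next_probs.clamp_min(1e-8).log())).sum(-1)
            return reward + self.gamma * (1.0 - done.float()) * next_v

    def _value_loss(self, chosen_q, target):
        return (chosen_q - target).pow(2)

    def _actor_loss(self, probs, q):
        # E_{a ~ pi}[alpha * log pi(a | s) - Q(s, a)], with Q treated as a constant.
        return (probs * (self.alpha * probs.clamp_min(1e-8).log() - q.detach())).sum(-1)

    def forward(self, reward, done, probs, q, chosen_q, next_probs, next_q):
        target = self._compute_target(reward, done, next_probs, next_q)
        return {
            "loss_qvalue": self._value_loss(chosen_q, target).mean(),
            "loss_actor": self._actor_loss(probs, q).mean(),
        }
```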