[Algorithm] Discrete CQL #1666
Conversation
Great work! I left some high-level comments; can you have a look?
Thanks for this
torchrl/objectives/cql.py
Outdated
logsumexp = torch.logsumexp(q_values, dim=-1, keepdim=True)
q_a = (q_values * current_action).sum(dim=-1, keepdim=True)

return (logsumexp - q_a).mean()
Can we return metadata too, like we're hoping to do for all losses in the future?
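(For reference, a minimal sketch of what returning metadata alongside the CQL term could look like, assuming q_values and current_action have the same shapes as in the snippet above; the metadata keys are illustrative only, not the actual implementation.)

def cql_loss(self, q_values, current_action):
    # log-sum-exp over all actions minus the Q-value of the action actually taken
    logsumexp = torch.logsumexp(q_values, dim=-1, keepdim=True)
    q_a = (q_values * current_action).sum(dim=-1, keepdim=True)
    loss = (logsumexp - q_a).mean()
    # illustrative metadata dict returned next to the differentiable loss tensor
    metadata = {
        "cql_logsumexp": logsumexp.mean().detach(),
        "cql_q_a": q_a.mean().detach(),
    }
    return loss, metadata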
torchrl/objectives/cql.py
Outdated
self._in_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
This should just be a couple of lines with dqn_loss and cql_loss, IMO.
I actually tried to inherit from the DQN class and then do something like super().forward(tensordict), with only the cql_loss calculation added, but I got circular import issues. Do you have any suggestions?
Oh, I wasn't suggesting inheriting from DQN; it's OK if they're separate. But the forward should just be a composition of loss_actor and loss_critic, like we did in other losses (e.g., TD3), where each sub-loss returns a tensor and a dict of metadata.
Ah, got it! It should be adapted accordingly now.
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
I don't understand it.
The CQL loss is called in the value loss, not in forward; why is that?
Why do we call item() on the CQL loss value? Conventionally, all losses in the output tensordict of a loss module should be differentiable.
Can you give me some context?
The CQL loss is more like an auxiliary term for the value loss, not a loss for a separate model like the actor; it just augments the value loss. We could separate it, but then we would need another forward pass through the model to obtain the current Q-values, which would slow things down, and I think there is no need to expose only the CQL loss, as on its own it's incomplete.
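(A minimal sketch of the idea described above, assuming one-hot actions and a pred_val tensor of Q-values already computed for the TD loss; variable names are illustrative, not the actual implementation.)

# reuse the Q-values already computed for the TD loss, so no extra forward pass is needed
q_taken = (pred_val * action).sum(dim=-1, keepdim=True)
cql_term = (torch.logsumexp(pred_val, dim=-1, keepdim=True) - q_taken).mean()
td_loss = distance_loss(q_taken.squeeze(-1), target_value, self.loss_function).mean()
loss_value = td_loss + cql_term  # the CQL regularizer simply augments the value loss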
torchrl/objectives/cql.py
Outdated
cql_loss = self.cql_loss(pred_val, action)

# calculate target value
with torch.no_grad():
    target_value = self.value_estimator.value_estimate(
        td_copy,
        target_params=self._cached_detached_target_value_params,
    ).squeeze(-1)

with torch.no_grad():
    td_error = (pred_val_index - target_value).pow(2)
    td_error = td_error.unsqueeze(-1)
    if tensordict.device is not None:
        td_error = td_error.to(tensordict.device)

tensordict.set(
    self.tensor_keys.priority,
    td_error,
    inplace=True,
)
loss = distance_loss(pred_val_index, target_value, self.loss_function).mean()

metadata = {
    "td_error": td_error.mean(0).detach(),
    "loss_cql": cql_loss.item(),
    "pred_value": pred_val.mean().detach(),
    "target_value": target_value.mean().detach(),
}
Where is the cql_loss used?
You are right, I must have deleted it. Sorry for the confusion; I just updated and fixed it :)
What do you think of BY571#1? I think being able to run ablation studies has some value. We need to fix the categorical case.
I think yes; if someone wants to check how the CQL loss term influences agent performance and wants a simple "on/off" capability, it makes sense. The changes you made look good. I also pushed some adaptations for the categorical case to calculate the CQL loss.
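(A minimal sketch of what the categorical-action case could look like, assuming actions are stored as integer indices rather than one-hot vectors; the function name and shapes are assumptions, not the actual implementation.)

def cql_loss_categorical(q_values, action):
    # q_values: [..., n_actions]; action: [...] or [..., 1] integer (long) indices
    if action.ndim == q_values.ndim - 1:
        action = action.unsqueeze(-1)
    logsumexp = torch.logsumexp(q_values, dim=-1, keepdim=True)
    # gather the Q-value of the taken action instead of masking with a one-hot vector
    q_a = q_values.gather(-1, action)
    return (logsumexp - q_a).mean()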
Cool, LMK when you've merged the PR.
That looks great!
Cool let's merge this!
Description
Adds a discrete (DQN) CQL objective and an example.
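(A hypothetical usage sketch of the objective added here; the class name, constructor arguments, and output keys are assumptions based on this discussion, not the definitive API.)

from torchrl.objectives import DiscreteCQLLoss

# value_net is assumed to be a discrete Q-value actor; data a batch sampled from an offline dataset
loss_module = DiscreteCQLLoss(value_net)
loss_td = loss_module(data)
# sum every differentiable "loss_*" entry and backpropagate
loss = sum(value for key, value in loss_td.items() if key.startswith("loss_"))
loss.backward()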
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!