[BugFix] Fix missing ("next", "observation") key in dispatch of losses #1235
Conversation
LGTM
see my minor comment
torchrl/objectives/a2c.py
Outdated
  >>> loss_val
- (tensor(1.7593, grad_fn=<MeanBackward0>), tensor(0.2344, grad_fn=<MeanBackward0>), tensor(1.5480), tensor(-0.0155, grad_fn=<MulBackward0>))
+ (tensor(4.3483, grad_fn=<MeanBackward0>), tensor(1.4114, grad_fn=<MeanBackward0>), tensor(2.5165), tensor(-0.0252, grad_fn=<MulBackward0>))
maybe we don't want to print these values (such that we don't pollute PRs with diffs that aren't relevant?)
Good point, I will just remove it.
torchrl/objectives/iql.py
Outdated
  ...     next_reward=torch.randn(*batch, 1))
  >>> loss_val
- (tensor(1.4535, grad_fn=<MeanBackward0>), tensor(0.8389, grad_fn=<MeanBackward0>), tensor(0.3406, grad_fn=<MeanBackward0>), tensor(3.3441))
+ (tensor(1.4535, grad_fn=<MeanBackward0>), tensor(0.7506, grad_fn=<MeanBackward0>), tensor(0.3406, grad_fn=<MeanBackward0>), tensor(3.3441))
Ditto
torchrl/objectives/a2c.py
Outdated
  >>> loss_val
- (tensor(4.3483, grad_fn=<MeanBackward0>), tensor(1.4114, grad_fn=<MeanBackward0>), tensor(2.5165), tensor(-0.0252, grad_fn=<MulBackward0>))
We can still print something :)
You could do
loss_actor, loss_val, *etc = fun_call(**kwargs)
loss_actor.backward() # for instance
anything that makes it feel like "this is a tensor"
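The suggested docstring pattern can be sketched as follows, assuming torch is available; `toy_loss_call` is a hypothetical stand-in for the loss module call in the docstring, not a TorchRL API:

```python
import torch

# Sketch of the reviewer's suggestion: unpack the returned losses and
# call .backward() on one of them, rather than printing raw tensor
# values that would churn in PR diffs. "toy_loss_call" is a
# hypothetical stand-in, not TorchRL's actual loss module.

def toy_loss_call():
    params = torch.ones(2, requires_grad=True)
    loss_actor = (params * 2).mean()   # a differentiable scalar
    loss_value = (params ** 2).sum()   # another differentiable scalar
    return loss_actor, loss_value, params

loss_actor, loss_value, params = toy_loss_call()
loss_actor.backward()  # demonstrates "this is a tensor" without printing values
assert params.grad is not None
```

Calling `.backward()` conveys that the outputs are differentiable tensors while keeping the docstring output stable across code changes.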
LGTM thanks!
Description
Fix missing ("next", "observation") key in the dispatch in_keys of the IQL, DDPG, and A2C loss modules.
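A minimal sketch of what the fix amounts to; the key layout and helper below are illustrative assumptions, not TorchRL's actual code. Dispatch maps keyword arguments onto a module's in_keys, with a nested key like ("next", "observation") corresponding to the keyword `next_observation`, so a missing entry means that keyword can never be routed in:

```python
# Hypothetical sketch of the bug and fix -- not TorchRL's real code.

def kwarg_name(key):
    # Flatten a nested key tuple into a dispatch keyword-argument name,
    # e.g. ("next", "observation") -> "next_observation".
    return "_".join(key) if isinstance(key, tuple) else key

# Before the fix: ("next", "observation") was absent, so the
# "next_observation" kwarg could not be dispatched.
in_keys_before = ["observation", "action", ("next", "reward"), ("next", "done")]
# After the fix: the key is present.
in_keys_after = in_keys_before + [("next", "observation")]

assert "next_observation" not in [kwarg_name(k) for k in in_keys_before]
assert "next_observation" in [kwarg_name(k) for k in in_keys_after]
```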
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!