-
Notifications
You must be signed in to change notification settings - Fork 326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Avoid reshape(-1)
for inputs to DreamerActorLoss
#2496
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2496
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 6 Unrelated FailuresAs of commit bca6b79 with merge base d894358 (): NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
loss_td, fake_data = loss_module(tensordict) | ||
# NOTE: Input is reshaped because GRUCell (which is part of the | ||
# RSSMPrior module in `mb_env`) expects input to be either 1D or 2D | ||
loss_td, fake_data = loss_module(tensordict.reshape(-1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if there is a better way to fix this test. I suppose it could be possible to just reshape the direct input to the GRUCell?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we need to reshape we should reshape - but another option here would be to use vmap
like:
if tensordict.ndim > 1:
loss_td, fake_data = vmap(loss_module, (0,))(tensordict)
(gru works with vmap as long as you are using the python only version in torchrl.modules)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I try using VmapModule
, I get this error:
File "/home/endoplasm/develop/torchrl-1/test/test_cost.py", line 10338, in test_dreamer_actor
loss_td, fake_data = VmapModule(loss_module, (0,))(tensordict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/endoplasm/develop/torchrl-1/torchrl/modules/tensordict_module/common.py", line 454, in __init__
self.in_keys = module.in_keys
^^^^^^^^^^^^^^
File "/home/endoplasm/develop/torchrl-1/torchrl/objectives/common.py", line 441, in __getattr__
return super().__getattr__(item)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/endoplasm/miniconda/envs/torchrl-1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
raise AttributeError(
AttributeError: 'DreamerActorLoss' object has no attribute 'in_keys'
I'll probably just leave the reshape for now, but I would like to understand this.
Indeed if I try to access loss_module.in_keys
directly, I also get the above error. But I can access the in_keys
of the actor model and world model within the loss module:
print(loss_module.actor_model.in_keys)
print(loss_module.model_based_env.world_model.in_keys)
['state', 'belief']
['state', 'belief', 'action']
So I'm wondering what would be the right way to make VmapModule
and DreamerActorLoss
compatible? Would we want to add an in_keys
attribute to DreamerActorLoss
that returns a combined list of the keys in the actor model and world model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah DreamerActorLoss
should have in_keys!
All losses should. Dreamer hasn't received much love lately as you can see.
Let's take care of that in a separate PR then
The Dreamer implementation (in examples workflow) is failing |
6009810
to
bca6b79
Compare
Should be fixed now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks!
loss_td, fake_data = loss_module(tensordict) | ||
# NOTE: Input is reshaped because GRUCell (which is part of the | ||
# RSSMPrior module in `mb_env`) expects input to be either 1D or 2D | ||
loss_td, fake_data = loss_module(tensordict.reshape(-1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah DreamerActorLoss
should have in_keys!
All losses should. Dreamer hasn't received much love lately as you can see.
Let's take care of that in a separate PR then
Description
Avoid reshaping inputs to
DreamerActorLoss
.Motivation and Context
Follow-up to #2494
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!