Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss #2496

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss
  • Loading branch information
kurtamohler committed Oct 16, 2024
commit bca6b799ed39cfcf350987f2bb51e7eb92de8f0f
4 changes: 3 additions & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def compile_rssms(module):
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)
actor_loss_td, sampled_tensordict = actor_loss(
sampled_tensordict.reshape(-1)
)

actor_opt.zero_grad()
if use_autocast:
Expand Down
2 changes: 1 addition & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10332,7 +10332,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est)
return
if td_est is not None:
loss_module.make_value_estimator(td_est)
loss_td, fake_data = loss_module(tensordict)
loss_td, fake_data = loss_module(tensordict.reshape(-1))
Copy link
Collaborator Author

@kurtamohler kurtamohler Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if there is a better way to fix this test. I suppose it could be possible to just reshape the direct input to the GRUCell?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we need to reshape we should reshape - but another option here would be to use vmap
like:

if tensordict.ndim > 1:
    loss_td, fake_data = vmap(loss_module, (0,))(tensordict)

(gru works with vmap as long as you are using the python only version in torchrl.modules)

Copy link
Collaborator Author

@kurtamohler kurtamohler Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I try using VmapModule, I get this error:

  File "/home/endoplasm/develop/torchrl-1/test/test_cost.py", line 10338, in test_dreamer_actor
    loss_td, fake_data = VmapModule(loss_module, (0,))(tensordict)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/modules/tensordict_module/common.py", line 454, in __init__
    self.in_keys = module.in_keys
                   ^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/objectives/common.py", line 441, in __getattr__
    return super().__getattr__(item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/miniconda/envs/torchrl-1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'DreamerActorLoss' object has no attribute 'in_keys'

I'll probably just leave the reshape for now, but I would like to understand this.

Indeed if I try to access loss_module.in_keys directly, I also get the above error. But I can access the in_keys of the actor model and world model within the loss module:

print(loss_module.actor_model.in_keys)
print(loss_module.model_based_env.world_model.in_keys)
['state', 'belief']
['state', 'belief', 'action']

So I'm wondering what would be the right way to make VmapModule and DreamerActorLoss compatible? Would we want to add an in_keys attribute to DreamerActorLoss that returns a combined list of the keys in the actor model and world model?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah DreamerActorLoss should have in_keys!
All losses should. Dreamer hasn't received much love lately as you can see.
Let's take care of that in a separate PR then

assert not fake_data.requires_grad
assert fake_data.shape == torch.Size([tensordict.numel(), imagination_horizon])
if discount_loss:
Expand Down
1 change: 0 additions & 1 deletion torchrl/objectives/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]:
tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
tensordict = tensordict.reshape(-1)

with timeit("actor_loss/time-rollout"), hold_out_net(
self.model_based_env
Expand Down
Loading