From a27514c23403cad295ec9ed2927b08d4a5d2767c Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Fri, 18 Oct 2024 06:41:27 -0700 Subject: [PATCH] [BugFix] Avoid `reshape(-1)` for inputs to `DreamerActorLoss` (#2496) --- sota-implementations/dreamer/dreamer.py | 4 +++- test/test_cost.py | 2 +- torchrl/objectives/dreamer.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f28fac8e675..992abea64e0 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -217,7 +217,9 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16 ) if use_autocast else contextlib.nullcontext(): - actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) + actor_loss_td, sampled_tensordict = actor_loss( + sampled_tensordict.reshape(-1) + ) actor_opt.zero_grad() if use_autocast: diff --git a/test/test_cost.py b/test/test_cost.py index 3530fff825d..6d8de531d49 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10332,7 +10332,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) return if td_est is not None: loss_module.make_value_estimator(td_est) - loss_td, fake_data = loss_module(tensordict) + loss_td, fake_data = loss_module(tensordict.reshape(-1)) assert not fake_data.requires_grad assert fake_data.shape == torch.Size([tensordict.numel(), imagination_horizon]) if discount_loss: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 30f6dd10699..73df58b7e56 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -271,7 +271,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief).detach() - tensordict = tensordict.reshape(-1) with timeit("actor_loss/time-rollout"), hold_out_net( self.model_based_env