diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 9e7e4421844..d511b069612 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -301,7 +301,7 @@ def forward(self, tensordict: TensorDictBase): batch_size=[nelts, tensordict_shaped.shape[-1]], ) else: - tensordict_shaped = tensordict.view(-1).unsqueeze(-1) + tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) is_init = tensordict_shaped.get("is_init").squeeze(-1) splits = None @@ -340,7 +340,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped.set(self.out_keys[2], hidden1) if splits is not None: # let's recover our original shape - tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).view( + tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape( tensordict_shaped_shape )