Skip to content

Commit

Permalink
Fix LSTM use with padded segments
Browse files Browse the repository at this point in the history
  • Loading branch information
smorad committed Jul 19, 2023
1 parent c96227a commit 3480805
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def forward(self, tensordict: TensorDictBase):
batch_size=[nelts, tensordict_shaped.shape[-1]],
)
else:
tensordict_shaped = tensordict.view(-1).unsqueeze(-1)
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

is_init = tensordict_shaped.get("is_init").squeeze(-1)
splits = None
Expand Down Expand Up @@ -340,7 +340,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped.set(self.out_keys[2], hidden1)
if splits is not None:
# let's recover our original shape
tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).view(
tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape(
tensordict_shaped_shape
)

Expand Down

0 comments on commit 3480805

Please sign in to comment.