Skip to content

Commit

Permalink
[BugFix] Fix LSTM use with padded/masked segments (pytorch#1399)
Browse files Browse the repository at this point in the history
  • Loading branch information
smorad authored Jul 26, 2023
1 parent c06ed70 commit dc8b7b5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
20 changes: 19 additions & 1 deletion test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest
import torch
from tensordict import TensorDict, unravel_key_list
from tensordict import pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, make_functional, TensorDictModule
from torch import nn
from torchrl.data.tensor_specs import (
Expand Down Expand Up @@ -1634,6 +1634,24 @@ def test_set_temporal_mode(self):
lstm_module.parameters()
)

def test_noncontiguous(self):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["bork", "h0", "h1"],
out_keys=["dork", ("next", "h0"), ("next", "h1")],
)
td = TensorDict(
{
"bork": torch.randn(3, 3),
"is_init": torch.zeros(3, 1, dtype=torch.bool),
},
[3],
)
padded = pad(td, [0, 5])
lstm_module(padded)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_singel_step(self, shape):
td = TensorDict(
Expand Down
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def forward(self, tensordict: TensorDictBase):
batch_size=[nelts, tensordict_shaped.shape[-1]],
)
else:
tensordict_shaped = tensordict.view(-1).unsqueeze(-1)
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

is_init = tensordict_shaped.get("is_init").squeeze(-1)
splits = None
Expand Down Expand Up @@ -340,7 +340,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped.set(self.out_keys[2], hidden1)
if splits is not None:
# let's recover our original shape
tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).view(
tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape(
tensordict_shaped_shape
)

Expand Down

0 comments on commit dc8b7b5

Please sign in to comment.