Skip to content

Commit

Permalink
[Feature] LSTMModule (pytorch#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 24, 2023
1 parent 4d47d46 commit 3486827
Show file tree
Hide file tree
Showing 7 changed files with 707 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ algorithms, such as DQN, DDPG or Dreamer.
DdpgMlpActor
DdpgMlpQNet
DreamerActor
LSTMModule
ObsEncoder
ObsDecoder
RSSMPrior
Expand Down
199 changes: 197 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
CompositeSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs.utils import set_exploration_type
from torchrl.modules import NormalParamWrapper, SafeModule, TanhNormal
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import LSTMModule, NormalParamWrapper, SafeModule, TanhNormal
from torchrl.modules.tensordict_module.common import (
ensure_tensordict_compatible,
is_tensordict_compatible,
Expand Down Expand Up @@ -1523,6 +1523,201 @@ def forward(self, in_1, in_2):
assert isinstance(ensured_module, TensorDictModule)


class TestLSTMModule:
def test_errs(self):
with pytest.raises(ValueError, match="batch_first"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=False,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
with pytest.raises(ValueError, match="in_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys=[
"observation",
"hidden0",
],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
with pytest.raises(ValueError, match="in_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys="abc",
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
with pytest.raises(ValueError, match="in_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_key="smth",
in_keys=[
"observation",
"hidden0",
],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
with pytest.raises(ValueError, match="out_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0")],
)
with pytest.raises(ValueError, match="out_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys="abc",
)
with pytest.raises(ValueError, match="out_keys"):
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_key="smth",
out_keys=["intermediate", ("next", "hidden0")],
)
lstm_module = LSTMModule(
input_size=3,
hidden_size=64,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
td = TensorDict({"observation": torch.randn(3)}, [])
with pytest.raises(KeyError, match="is_init"):
lstm_module(td)

def test_set_temporal_mode(self):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
assert lstm_module.set_recurrent_mode(False) is lstm_module
assert not lstm_module.set_recurrent_mode(False).temporal_mode
assert lstm_module.set_recurrent_mode(True) is not lstm_module
assert lstm_module.set_recurrent_mode(True).temporal_mode
assert set(lstm_module.set_recurrent_mode(True).parameters()) == set(
lstm_module.parameters()
)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_singel_step(self, shape):
td = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
td = lstm_module(td)
td_next = step_mdp(td, keep_other=True)
td_next = lstm_module(td_next)
assert not torch.isclose(
td_next["next", "hidden0"], td["next", "hidden0"]
).any()

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
@pytest.mark.parametrize("t", [1, 10])
def test_single_step_vs_multi(self, shape, t):
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
.unsqueeze(-1)
.expand(*shape, t, 3),
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
},
[*shape, t],
)
lstm_module_ss = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
lstm_module_ms = lstm_module_ss.set_recurrent_mode()
lstm_module_ms(td)
td_ss = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
for _t in range(t):
lstm_module_ss(td_ss)
td_ss = step_mdp(td_ss, keep_other=True)
td_ss["observation"][:] = _t + 1
torch.testing.assert_close(
td_ss["hidden0"], td["next", "hidden0"][..., -1, :, :]
)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_multi_consecutive(self, shape):
t = 20
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
.unsqueeze(-1)
.expand(*shape, t, 3),
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
},
[*shape, t],
)
if shape:
td["is_init"][0, ..., 13, :] = True
else:
td["is_init"][13, :] = True

lstm_module_ss = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
lstm_module_ms = lstm_module_ss.set_recurrent_mode()
lstm_module_ms(td)
td_ss = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
for _t in range(t):
td_ss["is_init"][:] = td["is_init"][..., _t, :]
lstm_module_ss(td_ss)
td_ss = step_mdp(td_ss, keep_other=True)
td_ss["observation"][:] = _t + 1
torch.testing.assert_close(
td_ss["intermediate"], td["intermediate"][..., -1, :]
)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DistributionalQValueHook,
DistributionalQValueModule,
EGreedyWrapper,
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
Expand Down
4 changes: 4 additions & 0 deletions torchrl/modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,10 @@ def __init__(
mlp_kwargs: Dict,
device: Optional[DEVICE_TYPING] = None,
) -> None:
warnings.warn(
"LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.",
category=DeprecationWarning,
)
super().__init__()
lstm_kwargs.update({"batch_first": True})
self.mlp = MLP(device=device, **mlp_kwargs)
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from .rnn import LSTMModule
from .sequence import SafeSequential
from .world_models import WorldModelWrapper
Loading

0 comments on commit 3486827

Please sign in to comment.