Skip to content

Commit

Permalink
[Feature] Support for GRU (pytorch#1586)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 5, 2023
1 parent f62785b commit 244f93a
Show file tree
Hide file tree
Showing 5 changed files with 682 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ algorithms, such as DQN, DDPG or Dreamer.
DistributionalDQNnet
DreamerActor
DuelingCnnDQNet
GRUModule
LSTMModule
ObsDecoder
ObsEncoder
Expand Down
262 changes: 260 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AdditiveGaussianWrapper,
DecisionTransformerInferenceWrapper,
DTActor,
GRUModule,
LSTMModule,
MLP,
NormalParamWrapper,
Expand Down Expand Up @@ -1645,9 +1646,9 @@ def test_set_temporal_mode(self):
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
assert lstm_module.set_recurrent_mode(False) is lstm_module
assert not lstm_module.set_recurrent_mode(False).temporal_mode
assert not lstm_module.set_recurrent_mode(False).recurrent_mode
assert lstm_module.set_recurrent_mode(True) is not lstm_module
assert lstm_module.set_recurrent_mode(True).temporal_mode
assert lstm_module.set_recurrent_mode(True).recurrent_mode
assert set(lstm_module.set_recurrent_mode(True).parameters()) == set(
lstm_module.parameters()
)
Expand Down Expand Up @@ -1822,6 +1823,263 @@ def create_transformed_env():
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()


class TestGRUModule:
def test_errs(self):
with pytest.raises(ValueError, match="batch_first"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=False,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
with pytest.raises(ValueError, match="in_keys"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=[
"observation",
"hidden0",
"hidden1",
],
out_keys=["intermediate", ("next", "hidden")],
)
with pytest.raises(TypeError, match="incompatible function arguments"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys="abc",
out_keys=["intermediate", ("next", "hidden")],
)
with pytest.raises(ValueError, match="in_keys"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_key="smth",
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden")],
)
with pytest.raises(ValueError, match="out_keys"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden"), "other"],
)
with pytest.raises(TypeError, match="incompatible function arguments"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys="abc",
)
with pytest.raises(ValueError, match="out_keys"):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_key="smth",
out_keys=["intermediate", ("next", "hidden"), "other"],
)
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
td = TensorDict({"observation": torch.randn(3)}, [])
with pytest.raises(KeyError, match="is_init"):
gru_module(td)

def test_set_temporal_mode(self):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
assert gru_module.set_recurrent_mode(False) is gru_module
assert not gru_module.set_recurrent_mode(False).recurrent_mode
assert gru_module.set_recurrent_mode(True) is not gru_module
assert gru_module.set_recurrent_mode(True).recurrent_mode
assert set(gru_module.set_recurrent_mode(True).parameters()) == set(
gru_module.parameters()
)

def test_noncontiguous(self):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["bork", "h"],
out_keys=["dork", ("next", "h")],
)
td = TensorDict(
{
"bork": torch.randn(3, 3),
"is_init": torch.zeros(3, 1, dtype=torch.bool),
},
[3],
)
padded = pad(td, [0, 5])
gru_module(padded)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_singel_step(self, shape):
td = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
td = gru_module(td)
td_next = step_mdp(td, keep_other=True)
td_next = gru_module(td_next)

assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any()

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
@pytest.mark.parametrize("t", [1, 10])
def test_single_step_vs_multi(self, shape, t):
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
.unsqueeze(-1)
.expand(*shape, t, 3),
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
},
[*shape, t],
)
gru_module_ss = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
gru_module_ms = gru_module_ss.set_recurrent_mode()
gru_module_ms(td)
td_ss = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
for _t in range(t):
gru_module_ss(td_ss)
td_ss = step_mdp(td_ss, keep_other=True)
td_ss["observation"][:] = _t + 1
torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :])

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_multi_consecutive(self, shape):
t = 20
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
.unsqueeze(-1)
.expand(*shape, t, 3),
"is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
},
[*shape, t],
)
if shape:
td["is_init"][0, ..., 13, :] = True
else:
td["is_init"][13, :] = True

gru_module_ss = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
)
gru_module_ms = gru_module_ss.set_recurrent_mode()
gru_module_ms(td)
td_ss = TensorDict(
{
"observation": torch.zeros(*shape, 3),
"is_init": torch.zeros(*shape, 1, dtype=torch.bool),
},
shape,
)
for _t in range(t):
td_ss["is_init"][:] = td["is_init"][..., _t, :]
gru_module_ss(td_ss)
td_ss = step_mdp(td_ss, keep_other=True)
td_ss["observation"][:] = _t + 1
torch.testing.assert_close(
td_ss["intermediate"], td["intermediate"][..., -1, :]
)

def test_gru_parallel_env(self):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

# tests that hidden states are carried over with parallel envs
gru_module = GRUModule(
input_size=7,
hidden_size=12,
num_layers=2,
in_key="observation",
out_key="features",
)

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

env = ParallelEnv(
create_env_fn=create_transformed_env,
num_workers=2,
)

mlp = TensorDictModule(
MLP(
in_features=12,
out_features=7,
num_cells=[],
),
in_keys=["features"],
out_keys=["logits"],
)

actor_model = TensorDictSequential(gru_module, mlp)

actor = ProbabilisticActor(
module=actor_model,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=True,
)
for break_when_any_done in [False, True]:
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()


def test_safe_specs():

out_key = ("a", "b")
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
DistributionalQValueModule,
EGreedyModule,
EGreedyWrapper,
GRUModule,
LMHeadActorValueOperator,
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from .rnn import LSTMModule
from .rnn import GRUModule, LSTMModule
from .sequence import SafeSequential
from .world_models import WorldModelWrapper
Loading

0 comments on commit 244f93a

Please sign in to comment.