Skip to content

Commit

Permalink
[BugFix] Fix LSTM - VecEnv compatibility (pytorch#1427)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 2, 2023
1 parent dbab7bb commit 2982515
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 20 deletions.
14 changes: 7 additions & 7 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,12 @@ def __new__(
shape=batch_size,
)
if action_spec is None:
action_spec_cls = (
DiscreteTensorSpec
if categorical_action_encoding
else OneHotDiscreteTensorSpec
)
action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
if categorical_action_encoding:
action_spec_cls = DiscreteTensorSpec
action_spec = action_spec_cls(n=7, shape=batch_size)
else:
action_spec_cls = OneHotDiscreteTensorSpec
action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
if reward_spec is None:
reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
if done_spec is None:
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def _step(
batch_size=self.batch_size,
device=self.device,
)
return tensordict.select().set("next", tensordict)
return tensordict


class NestedCountingEnv(CountingEnv):
Expand Down
2 changes: 1 addition & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_env_seed(env_name, frame_skip, seed=0):
env.set_seed(seed)
td0b = env.fake_tensordict()
td0b = env.reset(tensordict=td0b)
td1b = env.step(td0b.clone().set("action", action))
td1b = env.step(td0b.exclude("next").clone().set("action", action))

assert_allclose_td(td0a, td0b.select(*td0a.keys()))
assert_allclose_td(td1a, td1b)
Expand Down
58 changes: 57 additions & 1 deletion test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@

import pytest
import torch
from mocking_classes import DiscreteActionVecMockEnv
from tensordict import pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, make_functional, TensorDictModule
from tensordict.nn import (
InteractionType,
make_functional,
TensorDictModule,
TensorDictSequential,
)
from torch import nn
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
Expand All @@ -21,6 +27,7 @@
DecisionTransformerInferenceWrapper,
DTActor,
LSTMModule,
MLP,
NormalParamWrapper,
OnlineDTActor,
ProbabilisticActor,
Expand Down Expand Up @@ -1765,6 +1772,55 @@ def test_multi_consecutive(self, shape):
td_ss["intermediate"], td["intermediate"][..., -1, :]
)

def test_lstm_parallel_env(self):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
input_size=7,
hidden_size=12,
num_layers=2,
in_key="observation",
out_key="features",
)

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

env = ParallelEnv(
create_env_fn=create_transformed_env,
num_workers=2,
)

mlp = TensorDictModule(
MLP(
in_features=12,
out_features=7,
num_cells=[],
),
in_keys=["features"],
out_keys=["logits"],
)

actor_model = TensorDictSequential(lstm_module, mlp)

actor = ProbabilisticActor(
module=actor_model,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=True,
)
for break_when_any_done in [False, True]:
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state_c") != 0.0).any()
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()


def test_safe_specs():

Expand Down
11 changes: 7 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,9 @@ def test_transform_env_clone(self):
value_at_clone = td["next", "observation"].clone()
for _ in range(10):
td = env.rand_step(td)
assert (td["next", "observation"] != value_at_clone).any()
assert (
td["next", "observation"] == env.transform._cat_buffers_observation
).all()
td = step_mdp(td)
assert (td["observation"] != value_at_clone).any()
assert (td["observation"] == env.transform._cat_buffers_observation).all()
assert (
cloned._cat_buffers_observation == env.transform._cat_buffers_observation
).all()
Expand Down Expand Up @@ -6693,6 +6692,7 @@ def _test_vecnorm_subproc_auto(
tensordict = env.reset()
for _ in range(10):
tensordict = env.rand_step(tensordict)
tensordict = step_mdp(tensordict)
queue_out.put(True)
msg = queue_in.get(timeout=TIMEOUT)
assert msg == "all_done"
Expand Down Expand Up @@ -6800,11 +6800,13 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
assert msg == "start"
for _ in range(10):
tensordict = parallel_env.rand_step(tensordict)
tensordict = step_mdp(tensordict)
queue_out.put("first round")
msg = queue_in.get(timeout=TIMEOUT)
assert msg == "start"
for _ in range(10):
tensordict = parallel_env.rand_step(tensordict)
tensordict = step_mdp(tensordict)
queue_out.put("second round")
parallel_env.close()
queue_out.close()
Expand Down Expand Up @@ -6884,6 +6886,7 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200):
for _ in range(N):
td = env_t.rand_step(td)
tds.append(td.clone())
td = step_mdp(td)
if td.get("done").any():
td = env_t.reset()
tds = torch.stack(tds, 0)
Expand Down
19 changes: 14 additions & 5 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
Args:
tensordict (TensorDictBase): Tensordict containing the action to be taken.
If the input tensordict contains a ``"next"`` entry, the values contained in it
will prevail over the newly computed values. This gives a mechanism
to override the underlying computations.
Returns:
the input tensordict, modified in place with the resulting observations, done state and reward
Expand All @@ -1126,10 +1129,13 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
# sanity check
self._assert_tensordict_shape(tensordict)
next_preset = tensordict.get("next", None)

next_tensordict = self._step(tensordict)
next_tensordict = self._step_proc_data(next_tensordict)
# tensordict could already have a "next" key
if next_preset is not None:
# tensordict could already have a "next" key
next_tensordict.update(next_preset)
tensordict.set("next", next_tensordict)
return tensordict

Expand Down Expand Up @@ -1669,11 +1675,14 @@ def fake_tensordict(self) -> TensorDictBase:
next_output.update(fake_reward)
next_output.update(fake_done)
fake_in_out.update(fake_done.clone())
if "next" not in fake_in_out.keys():
fake_in_out.set("next", next_output)
else:
fake_in_out.get("next").update(next_output)

fake_td = fake_in_out.set("next", next_output)
fake_td.batch_size = self.batch_size
fake_td = fake_td.to(self.device)
return fake_td
fake_in_out.batch_size = self.batch_size
fake_in_out = fake_in_out.to(self.device)
return fake_in_out


class _EnvWrapper(EnvBase, metaclass=abc.ABCMeta):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Optional, Tuple

import torch
from tensordict import unravel_key_list
from tensordict import TensorDictBase, unravel_key_list

from tensordict.nn import TensorDictModuleBase as ModuleBase

from tensordict.tensordict import NO_DEFAULT, TensorDictBase
from tensordict.tensordict import NO_DEFAULT
from tensordict.utils import prod

from torch import nn
Expand Down

0 comments on commit 2982515

Please sign in to comment.