Single batch single step LSTM bug fix (pytorch#173)
vmoens authored May 31, 2022
1 parent d40b08d commit 40b2b0a
Showing 7 changed files with 199 additions and 13 deletions.
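
For context, here is a minimal sketch (not part of the diff) of the situation this commit fixes: feeding LSTMNet a single, unbatched step, as exercised by the new test_lstm_net_nobatch test below. The import path and the sizes are assumptions, not values taken from the commit.

# Hypothetical usage sketch; LSTMNet import path and sizes are assumed.
import torch
from torchrl.modules import LSTMNet

hidden_size, out_features, in_features = 5, 3, 4
net = LSTMNet(
    out_features,
    {"input_size": hidden_size, "hidden_size": hidden_size},
    {"out_features": hidden_size},
)
x = torch.randn(in_features)  # one observation, no batch and no time dimension
# With this fix, the 1D input (and, on later calls, its single-step hidden
# states) is unsqueezed internally and the outputs are squeezed back.
y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
print(y.shape)  # torch.Size([3])
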
18 changes: 16 additions & 2 deletions examples/ppo/ppo.py
@@ -17,6 +17,7 @@
import argparse

_configargparse = False

import torch.cuda
from torch.utils.tensorboard import SummaryWriter
from torchrl.envs.transforms import RewardScaling, TransformedEnv
@@ -100,7 +101,11 @@ def main(args):
args=args, use_env_creator=False, stats=stats
)()

model = make_ppo_model(proof_env, args=args, device=device)
model = make_ppo_model(
proof_env,
args=args,
device=device,
)
actor_model = model.get_policy_operator()

loss_module = make_ppo_loss(model, args)
@@ -112,6 +117,7 @@ def main(args):
del proof_td
else:
action_dim_gsde, state_dim_gsde = None, None

proof_env.close()
create_env_fn = parallel_env_constructor(
args=args,
@@ -136,6 +142,7 @@ def main(args):
norm_obs_only=True,
stats=stats,
writer=writer,
use_env_creator=False,
)()

# remove video recorder from recorder to have matching state_dict keys
@@ -159,7 +166,14 @@ def main(args):
t.loc.fill_(0.0)

trainer = make_trainer(
collector, loss_module, recorder, None, actor_model, None, writer, args
collector,
loss_module,
recorder,
None,
actor_model,
None,
writer,
args,
)
if args.loss == "kl":
trainer.register_op("pre_optim_steps", loss_module.reset)
130 changes: 128 additions & 2 deletions test/test_modules.py
@@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from numbers import Number

import pytest
@@ -17,6 +17,7 @@
TensorDictModule,
ValueOperator,
ProbabilisticActor,
LSTMNet,
)
from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear

@@ -176,5 +177,130 @@ def test_actorcritic(device):
) == len(policy_params)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("out_features", [3, 4])
@pytest.mark.parametrize("hidden_size", [8, 9])
@pytest.mark.parametrize("num_layers", [1, 2])
@pytest.mark.parametrize("has_precond_hidden", [True, False])
def test_lstm_net(device, out_features, hidden_size, num_layers, has_precond_hidden):
batch = 5
time_steps = 6
in_features = 7
net = LSTMNet(
out_features,
{
"input_size": hidden_size,
"hidden_size": hidden_size,
"num_layers": num_layers,
},
{"out_features": hidden_size},
).to(device)
# test single step vs multi-step
x = torch.randn(batch, time_steps, in_features, device=device)
x_unbind = x.unbind(1)
tds_loop = []
if has_precond_hidden:
hidden0_out0, hidden1_out0 = torch.randn(
2, batch, time_steps, num_layers, hidden_size, device=device
)
hidden0_out0[:, 1:] = 0.0
hidden1_out0[:, 1:] = 0.0
hidden0_out = hidden0_out0[:, 0]
hidden1_out = hidden1_out0[:, 0]
else:
hidden0_out, hidden1_out = None, None
hidden0_out0, hidden1_out0 = None, None

for _x in x_unbind:
y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(
_x, hidden0_out, hidden1_out
)
td = TensorDict(
{
"y": y,
"hidden0_in": hidden0_in,
"hidden1_in": hidden1_in,
"hidden0_out": hidden0_out,
"hidden1_out": hidden1_out,
},
[batch],
)
tds_loop.append(td)
tds_loop = torch.stack(tds_loop, 1)

y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(
x, hidden0_out0, hidden1_out0
)
tds_vec = TensorDict(
{
"y": y,
"hidden0_in": hidden0_in,
"hidden1_in": hidden1_in,
"hidden0_out": hidden0_out,
"hidden1_out": hidden1_out,
},
[batch, time_steps],
)
torch.testing.assert_close(tds_vec["y"], tds_loop["y"])
torch.testing.assert_close(
tds_vec["hidden0_out"][:, -1], tds_loop["hidden0_out"][:, -1]
)
torch.testing.assert_close(
tds_vec["hidden1_out"][:, -1], tds_loop["hidden1_out"][:, -1]
)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("out_features", [3, 5])
@pytest.mark.parametrize("hidden_size", [3, 5])
def test_lstm_net_nobatch(device, out_features, hidden_size):
time_steps = 6
in_features = 4
net = LSTMNet(
out_features,
{"input_size": hidden_size, "hidden_size": hidden_size},
{"out_features": hidden_size},
).to(device)
# test single step vs multi-step
x = torch.randn(time_steps, in_features, device=device)
x_unbind = x.unbind(0)
tds_loop = []
hidden0_in, hidden1_in, hidden0_out, hidden1_out = [
None,
] * 4
for _x in x_unbind:
y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(
_x, hidden0_out, hidden1_out
)
td = TensorDict(
{
"y": y,
"hidden0_in": hidden0_in,
"hidden1_in": hidden1_in,
"hidden0_out": hidden0_out,
"hidden1_out": hidden1_out,
},
[],
)
tds_loop.append(td)
tds_loop = torch.stack(tds_loop, 0)

y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x.unsqueeze(0))
tds_vec = TensorDict(
{
"y": y,
"hidden0_in": hidden0_in,
"hidden1_in": hidden1_in,
"hidden0_out": hidden0_out,
"hidden1_out": hidden1_out,
},
[1, time_steps],
).squeeze(0)
torch.testing.assert_close(tds_vec["y"], tds_loop["y"])
torch.testing.assert_close(tds_vec["hidden0_out"][-1], tds_loop["hidden0_out"][-1])
torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1])


if __name__ == "__main__":
pytest.main([__file__])
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
4 changes: 3 additions & 1 deletion test/test_postprocs.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import pytest
import torch
@@ -167,4 +168,5 @@ def test_splits(self, num_workers, traj_len):


if __name__ == "__main__":
pytest.main([__file__, "--capture", "no"])
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
10 changes: 10 additions & 0 deletions torchrl/data/tensordict/tensordict.py
@@ -1817,6 +1817,8 @@ def memmap_(self) -> _TensorDict:

def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict:
if isinstance(dest, type) and issubclass(dest, _TensorDict):
if isinstance(self, dest):
return self
td = dest(
source=self,
**kwargs,
@@ -2270,6 +2272,8 @@ def set_(

def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict:
if isinstance(dest, type) and issubclass(dest, _TensorDict):
if isinstance(self, dest):
return self
return dest(
source=self.clone(),
)
@@ -2744,6 +2748,8 @@ def pin_memory(self) -> _TensorDict:

def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict:
if isinstance(dest, type) and issubclass(dest, _TensorDict):
if isinstance(self, dest):
return self
return dest(source=self, batch_size=self.batch_size)
elif isinstance(dest, (torch.device, str, int)):
dest = torch.device(dest)
@@ -3177,6 +3183,8 @@ def __repr__(self) -> str:

def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs):
if isinstance(dest, type) and issubclass(dest, _TensorDict):
if isinstance(self, dest):
return self
td = dest(
source=TensorDict(self.to_dict(), batch_size=self.batch_size),
**kwargs,
@@ -3482,6 +3490,8 @@ def del_(self, key: str) -> _CustomOpTensorDict:

def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict:
if isinstance(dest, type) and issubclass(dest, _TensorDict):
if isinstance(self, dest):
return self
return dest(source=self.contiguous().clone())
elif isinstance(dest, (torch.device, str, int)):
if self._device_safe() is not None and torch.device(dest) == self.device:
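
The tensordict.py changes above make to(dest) a no-op whenever the tensordict is already an instance of the requested class. A minimal sketch of the resulting behaviour, assuming TensorDict is importable from torchrl.data (the import path is an assumption):

# Hypothetical sketch; the import path is an assumption.
import torch
from torchrl.data import TensorDict

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
# With the early return added above, converting to a type the object already
# has returns the same object instead of rebuilding it from its source.
assert td.to(TensorDict) is td
# Device conversions still follow the usual path.
td_cpu = td.to("cpu")
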
48 changes: 40 additions & 8 deletions torchrl/modules/models/models.py
@@ -883,11 +883,32 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens


class LSTMNet(nn.Module):
"""
An embedder for an LSTM followed by an MLP.
"""An embedder for an LSTM preceded by an MLP.
The forward method returns the output along with the hidden states of the current step (input hidden states)
and of the next step (output hidden states), much as the environment returns both the 'observation' and the 'next_observation'.
Because the LSTM kernel only returns the last hidden state, hidden states
are padded with zeros such that they have the right size to be stored in a
TensorDict of size [batch x time_steps].
If a 2D tensor is provided as input, it is assumed that it is a batch of data
with only one time step. This means that we explicitly assume that users will
unsqueeze inputs of a single batch with multiple time steps.
Examples:
>>> batch = 7
>>> time_steps = 6
>>> in_features = 4
>>> hidden_size = 5
>>> out_features = 3
>>> net = LSTMNet(
... out_features,
... {"input_size": hidden_size, "hidden_size": hidden_size},
... {"out_features": hidden_size},
... )
>>> # test single step vs multi-step
>>> x = torch.randn(batch, time_steps, in_features)
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
"""

def __init__(self, out_features, lstm_kwargs: Dict, mlp_kwargs: Dict) -> None:
@@ -903,14 +924,19 @@ def _lstm(
hidden0_in: Optional[torch.Tensor] = None,
hidden1_in: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
squeeze = False
squeeze0 = False
squeeze1 = False
if input.ndimension() == 1:
squeeze0 = True
input = input.unsqueeze(0).contiguous()

if input.ndimension() == 2:
squeeze = True
squeeze1 = True
input = input.unsqueeze(1).contiguous()
batch, steps = input.shape[:2]

if hidden1_in is None and hidden0_in is None:
shape = (batch, steps) if not squeeze else (batch,)
shape = (batch, steps) if not squeeze1 else (batch,)
hidden0_in, hidden1_in = [
torch.zeros(
*shape,
@@ -925,9 +951,12 @@
raise RuntimeError(
f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
)
elif squeeze0:
hidden0_in = hidden0_in.unsqueeze(0)
hidden1_in = hidden1_in.unsqueeze(0)

# we only need the first hidden state
if not squeeze:
if not squeeze1:
_hidden0_in = hidden0_in[:, 0]
_hidden1_in = hidden1_in[:, 0]
else:
@@ -944,16 +973,19 @@
y = self.linear(y0)

out = [y, hidden0_in, hidden1_in, *hidden]
if squeeze:
if squeeze1:
# squeezes time
out[0] = out[0].squeeze(1)
else:
if not squeeze1:
# we pad the hidden states with zero to make tensordict happy
for i in range(3, 5):
out[i] = torch.stack(
[torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)]
+ [out[i]],
1,
)
if squeeze0:
out = [_out.squeeze(0) for _out in out]
return tuple(out)

def forward(
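
The docstring above notes that, because the LSTM kernel only returns the last hidden state, the returned hidden states are zero-padded along the time dimension so that they fit in a TensorDict of batch size [batch x time_steps]. A minimal sketch of that behaviour (import path and sizes are assumptions):

# Hypothetical sketch; import path and sizes are assumed.
import torch
from torchrl.modules import LSTMNet

batch, time_steps, in_features, hidden_size, out_features = 2, 6, 4, 5, 3
net = LSTMNet(
    out_features,
    {"input_size": hidden_size, "hidden_size": hidden_size},
    {"out_features": hidden_size},
)
y, h0_in, h1_in, h0_out, h1_out = net(torch.randn(batch, time_steps, in_features))
# The output hidden states carry a time dimension so they can be stacked next
# to y, but only the last step holds the actual LSTM hidden state; the rest is
# zero padding (hence the tests above only compare the last time step).
print(h0_out.shape)                # torch.Size([2, 6, 1, 5]) with num_layers=1
print(h0_out[:, :-1].abs().max())  # tensor(0.)
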
1 change: 1 addition & 0 deletions torchrl/trainers/helpers/trainers.py
@@ -220,6 +220,7 @@ def make_trainer(
"post_steps_log",
recorder_obj,
)
recorder_obj(None)
recorder_obj_explore = Recorder(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
1 change: 1 addition & 0 deletions torchrl/trainers/trainers.py
@@ -631,6 +631,7 @@ def update_reward_stats(self, batch: _TensorDict) -> None:
self._update_has_been_called = True

def normalize_reward(self, tensordict: _TensorDict) -> _TensorDict:
tensordict = tensordict.to_tensordict() # make sure it is not a SubTensorDict
reward = tensordict.get("reward")
reward = reward - self._reward_stats["mean"].to(tensordict.device)
reward = reward / self._reward_stats["std"].to(tensordict.device)
