diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index 8065bf8770c..72520ed47c0 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -17,6 +17,7 @@ import argparse _configargparse = False + import torch.cuda from torch.utils.tensorboard import SummaryWriter from torchrl.envs.transforms import RewardScaling, TransformedEnv @@ -100,7 +101,11 @@ def main(args): args=args, use_env_creator=False, stats=stats )() - model = make_ppo_model(proof_env, args=args, device=device) + model = make_ppo_model( + proof_env, + args=args, + device=device, + ) actor_model = model.get_policy_operator() loss_module = make_ppo_loss(model, args) @@ -112,6 +117,7 @@ def main(args): del proof_td else: action_dim_gsde, state_dim_gsde = None, None + proof_env.close() create_env_fn = parallel_env_constructor( args=args, @@ -136,6 +142,7 @@ def main(args): norm_obs_only=True, stats=stats, writer=writer, + use_env_creator=False, )() # remove video recorder from recorder to have matching state_dict keys @@ -159,7 +166,14 @@ def main(args): t.loc.fill_(0.0) trainer = make_trainer( - collector, loss_module, recorder, None, actor_model, None, writer, args + collector, + loss_module, + recorder, + None, + actor_model, + None, + writer, + args, ) if args.loss == "kl": trainer.register_op("pre_optim_steps", loss_module.reset) diff --git a/test/test_modules.py b/test/test_modules.py index b8e022c0d28..180a78df67e 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import argparse from numbers import Number import pytest @@ -17,6 +17,7 @@ TensorDictModule, ValueOperator, ProbabilisticActor, + LSTMNet, ) from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear @@ -176,5 +177,130 @@ def test_actorcritic(device): ) == len(policy_params) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("out_features", [3, 4]) +@pytest.mark.parametrize("hidden_size", [8, 9]) +@pytest.mark.parametrize("num_layers", [1, 2]) +@pytest.mark.parametrize("has_precond_hidden", [True, False]) +def test_lstm_net(device, out_features, hidden_size, num_layers, has_precond_hidden): + batch = 5 + time_steps = 6 + in_features = 7 + net = LSTMNet( + out_features, + { + "input_size": hidden_size, + "hidden_size": hidden_size, + "num_layers": num_layers, + }, + {"out_features": hidden_size}, + ).to(device) + # test single step vs multi-step + x = torch.randn(batch, time_steps, in_features, device=device) + x_unbind = x.unbind(1) + tds_loop = [] + if has_precond_hidden: + hidden0_out0, hidden1_out0 = torch.randn( + 2, batch, time_steps, num_layers, hidden_size, device=device + ) + hidden0_out0[:, 1:] = 0.0 + hidden1_out0[:, 1:] = 0.0 + hidden0_out = hidden0_out0[:, 0] + hidden1_out = hidden1_out0[:, 0] + else: + hidden0_out, hidden1_out = None, None + hidden0_out0, hidden1_out0 = None, None + + for _x in x_unbind: + y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( + _x, hidden0_out, hidden1_out + ) + td = TensorDict( + { + "y": y, + "hidden0_in": hidden0_in, + "hidden1_in": hidden1_in, + "hidden0_out": hidden0_out, + "hidden1_out": hidden1_out, + }, + [batch], + ) + tds_loop.append(td) + tds_loop = torch.stack(tds_loop, 1) + + y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( + x, hidden0_out0, hidden1_out0 + ) + tds_vec = TensorDict( + { + "y": y, + "hidden0_in": hidden0_in, + "hidden1_in": hidden1_in, + "hidden0_out": hidden0_out, + "hidden1_out": hidden1_out, + }, + [batch, time_steps], + ) + torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) + torch.testing.assert_close( + tds_vec["hidden0_out"][:, -1], tds_loop["hidden0_out"][:, -1] + ) + torch.testing.assert_close( + tds_vec["hidden1_out"][:, -1], tds_loop["hidden1_out"][:, -1] + ) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("out_features", [3, 5]) +@pytest.mark.parametrize("hidden_size", [3, 5]) +def test_lstm_net_nobatch(device, out_features, hidden_size): + time_steps = 6 + in_features = 4 + net = LSTMNet( + out_features, + {"input_size": hidden_size, "hidden_size": hidden_size}, + {"out_features": hidden_size}, + ).to(device) + # test single step vs multi-step + x = torch.randn(time_steps, in_features, device=device) + x_unbind = x.unbind(0) + tds_loop = [] + hidden0_in, hidden1_in, hidden0_out, hidden1_out = [ + None, + ] * 4 + for _x in x_unbind: + y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( + _x, hidden0_out, hidden1_out + ) + td = TensorDict( + { + "y": y, + "hidden0_in": hidden0_in, + "hidden1_in": hidden1_in, + "hidden0_out": hidden0_out, + "hidden1_out": hidden1_out, + }, + [], + ) + tds_loop.append(td) + tds_loop = torch.stack(tds_loop, 0) + + y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x.unsqueeze(0)) + tds_vec = TensorDict( + { + "y": y, + "hidden0_in": hidden0_in, + "hidden1_in": hidden1_in, + "hidden0_out": hidden0_out, + "hidden1_out": hidden1_out, + }, + [1, time_steps], + ).squeeze(0) + torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) + torch.testing.assert_close(tds_vec["hidden0_out"][-1], tds_loop["hidden0_out"][-1]) + torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) + + if __name__ == "__main__": - pytest.main([__file__]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 5a976f409b2..b108b29943c 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse import pytest import torch @@ -167,4 +168,5 @@ def test_splits(self, num_workers, traj_len): if __name__ == "__main__": - pytest.main([__file__, "--capture", "no"]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py index c311e40dce5..0920156926f 100644 --- a/torchrl/data/tensordict/tensordict.py +++ b/torchrl/data/tensordict/tensordict.py @@ -1817,6 +1817,8 @@ def memmap_(self) -> _TensorDict: def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict: if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(self, dest): + return self td = dest( source=self, **kwargs, @@ -2270,6 +2272,8 @@ def set_( def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict: if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(self, dest): + return self return dest( source=self.clone(), ) @@ -2744,6 +2748,8 @@ def pin_memory(self) -> _TensorDict: def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(self, dest): + return self return dest(source=self, batch_size=self.batch_size) elif isinstance(dest, (torch.device, str, int)): dest = torch.device(dest) @@ -3177,6 +3183,8 @@ def __repr__(self) -> str: def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs): if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(self, dest): + return self td = dest( source=TensorDict(self.to_dict(), batch_size=self.batch_size), **kwargs, @@ -3482,6 +3490,8 @@ def del_(self, key: str) -> _CustomOpTensorDict: def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(self, dest): + return self return dest(source=self.contiguous().clone()) elif isinstance(dest, (torch.device, str, int)): if self._device_safe() is not None and torch.device(dest) == self.device: diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 758c9bcfa03..d16307722c1 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -883,11 +883,32 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens class LSTMNet(nn.Module): - """ - An embedder for an LSTM followed by an MLP. + """An embedder for an LSTM preceded by an MLP. + The forward method returns the hidden states of the current state (input hidden states) and the output, as the environment returns the 'observation' and 'next_observation'. + Because the LSTM kernel only returns the last hidden state, hidden states + are padded with zeros such that they have the right size to be stored in a + TensorDict of size [batch x time_steps]. + + If a 2D tensor is provided as input, it is assumed that it is a batch of data + with only one time step. This means that we explicitely assume that users will + unsqueeze inputs of a single batch with multiple time steps. + + Examples: + >>> batch = 7 + >>> time_steps = 6 + >>> in_features = 4 + >>> net = LSTMNet( + ... out_features, + ... {"input_size": hidden_size, "hidden_size": hidden_size}, + ... {"out_features": hidden_size}, + ... ) + >>> # test single step vs multi-step + >>> x = torch.randn(batch, time_steps, in_features) + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + """ def __init__(self, out_features, lstm_kwargs: Dict, mlp_kwargs: Dict) -> None: @@ -903,14 +924,19 @@ def _lstm( hidden0_in: Optional[torch.Tensor] = None, hidden1_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - squeeze = False + squeeze0 = False + squeeze1 = False + if input.ndimension() == 1: + squeeze0 = True + input = input.unsqueeze(0).contiguous() + if input.ndimension() == 2: - squeeze = True + squeeze1 = True input = input.unsqueeze(1).contiguous() batch, steps = input.shape[:2] if hidden1_in is None and hidden0_in is None: - shape = (batch, steps) if not squeeze else (batch,) + shape = (batch, steps) if not squeeze1 else (batch,) hidden0_in, hidden1_in = [ torch.zeros( *shape, @@ -925,9 +951,12 @@ def _lstm( raise RuntimeError( f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" ) + elif squeeze0: + hidden0_in = hidden0_in.unsqueeze(0) + hidden1_in = hidden1_in.unsqueeze(0) # we only need the first hidden state - if not squeeze: + if not squeeze1: _hidden0_in = hidden0_in[:, 0] _hidden1_in = hidden1_in[:, 0] else: @@ -944,9 +973,10 @@ def _lstm( y = self.linear(y0) out = [y, hidden0_in, hidden1_in, *hidden] - if squeeze: + if squeeze1: + # squeezes time out[0] = out[0].squeeze(1) - else: + if not squeeze1: # we pad the hidden states with zero to make tensordict happy for i in range(3, 5): out[i] = torch.stack( @@ -954,6 +984,8 @@ def _lstm( + [out[i]], 1, ) + if squeeze0: + out = [_out.squeeze(0) for _out in out] return tuple(out) def forward( diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 1d03a553313..cc2ba751f2b 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -220,6 +220,7 @@ def make_trainer( "post_steps_log", recorder_obj, ) + recorder_obj(None) recorder_obj_explore = Recorder( record_frames=args.record_frames, frame_skip=args.frame_skip, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index e763f551d68..7ffcbb7dee4 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -631,6 +631,7 @@ def update_reward_stats(self, batch: _TensorDict) -> None: self._update_has_been_called = True def normalize_reward(self, tensordict: _TensorDict) -> _TensorDict: + tensordict = tensordict.to_tensordict() # make sure it is not a SubTensorDict reward = tensordict.get("reward") reward = reward - self._reward_stats["mean"].to(tensordict.device) reward = reward / self._reward_stats["std"].to(tensordict.device)