[Versioning] Deprecations for 0.4 (#2109)
vmoens authored Apr 25, 2024
1 parent 93e9e30 commit 36e1309
Showing 36 changed files with 243 additions and 1,098 deletions.
1 change: 0 additions & 1 deletion docs/source/reference/modules.rst
@@ -315,7 +315,6 @@ Regular modules
MLP
ConvNet
Conv3dNet
LSTMNet
SqueezeLayer
Squeeze2dLayer

@@ -149,8 +149,10 @@ def _create_and_launch_data_collectors(self) -> None:


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
"""Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` means all of it's public methods are remotely invokable using `torch.rpc`.
Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures.
"""Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer`
means all of it's public methods are remotely invokable using `torch.rpc`.
Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation
cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures.
Args:
capacity (int): the maximum number of elements that can be stored in the replay buffer.
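For context, a minimal sketch of the storage the docstring recommends; it is not part of this commit, and the size and scratch directory below are arbitrary:

# Illustrative sketch only -- not part of this diff.
from torchrl.data.replay_buffers import LazyMemmapStorage

# Memory-mapped storage keeps replay data in files on disk: RPC calls then
# serialise little beyond indices, and a restarted node can re-attach to the
# same scratch_dir, which is what makes recovery from node failures cheaper.
storage = LazyMemmapStorage(max_size=10_000, scratch_dir="/tmp/replay")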
4 changes: 2 additions & 2 deletions examples/memmap/memmap_speed_distributed.py
@@ -9,7 +9,7 @@
import configargparse
import torch
import torch.distributed.rpc as rpc
from tensordict import MemmapTensor
from tensordict import MemoryMappedTensor

parser = configargparse.ArgumentParser()
parser.add_argument("--rank", default=-1, type=int)
@@ -59,7 +59,7 @@ def op_on_tensor(idx):
# create tensor
tensor = torch.zeros(10000, 10000)
if tensortype == "memmap":
tensor = MemmapTensor(tensor)
tensor = MemoryMappedTensor.from_tensor(tensor)
elif tensortype == "tensor":
pass
else:
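The change above shows the general migration for this deprecation: the MemmapTensor wrapper is replaced by the MemoryMappedTensor.from_tensor classmethod. A small sketch, not part of this diff, with an arbitrary shape:

# Illustrative migration sketch only -- not part of this diff.
import torch
from tensordict import MemoryMappedTensor

t = torch.zeros(1000, 1000)
# Previously: mm = MemmapTensor(t)
# Now the classmethod copies the tensor into a memory-mapped buffer:
mm = MemoryMappedTensor.from_tensor(t)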
156 changes: 156 additions & 0 deletions test/_utils_internal.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import os
@@ -18,6 +19,7 @@
import torch.cuda

from tensordict import tensorclass, TensorDict
from torch import nn
from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator
from torchrl.data.utils import CloudpickleWrapper

@@ -35,6 +37,8 @@
# Specified for test_utils.py
__version__ = "0.3"

from torchrl.modules import MLP


def CARTPOLE_VERSIONED():
# load gym
@@ -498,3 +502,155 @@ def new_func(*args, **kwargs):
return func(*args, **kwargs)

return CloudpickleWrapper(new_func)


class LSTMNet(nn.Module):
"""An embedder for an LSTM preceded by an MLP.
The forward method returns the hidden states of the current state
(input hidden states) and the output, as
the environment returns the 'observation' and 'next_observation'.
Because the LSTM kernel only returns the last hidden state, hidden states
are padded with zeros such that they have the right size to be stored in a
TensorDict of size [batch x time_steps].
If a 2D tensor is provided as input, it is assumed that it is a batch of data
with only one time step. This means that we explicitly assume that users will
unsqueeze inputs of a single batch with multiple time steps.
Args:
out_features (int): number of output features.
lstm_kwargs (dict): the keyword arguments for the
:class:`~torch.nn.LSTM` layer.
mlp_kwargs (dict): the keyword arguments for the
:class:`~torchrl.modules.MLP` layer.
device (torch.device, optional): the device where the module should
be instantiated.
Keyword Args:
lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that
indicates where the LSTM class is to be retrieved. The ``"torchrl"``
backend (:class:`~torchrl.modules.LSTM`) is slower but works with
:func:`~torch.vmap` and should work with :func:`~torch.compile`.
Defaults to ``"torch"``.
Examples:
>>> batch = 7
>>> time_steps = 6
>>> in_features = 4
>>> out_features = 10
>>> hidden_size = 5
>>> net = LSTMNet(
... out_features,
... {"input_size": hidden_size, "hidden_size": hidden_size},
... {"out_features": hidden_size},
... )
>>> # test single step vs multi-step
>>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
>>> x = torch.randn(batch, in_features) # 2 dims = single step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
"""

def __init__(
self,
out_features: int,
lstm_kwargs,
mlp_kwargs,
device=None,
*,
lstm_backend: str | None = None,
) -> None:
super().__init__()
lstm_kwargs.update({"batch_first": True})
self.mlp = MLP(device=device, **mlp_kwargs)
if lstm_backend is None:
lstm_backend = "torch"
self.lstm_backend = lstm_backend
if self.lstm_backend == "torch":
LSTM = nn.LSTM
else:
from torchrl.modules.tensordict_module.rnn import LSTM
self.lstm = LSTM(device=device, **lstm_kwargs)
self.linear = nn.LazyLinear(out_features, device=device)

def _lstm(
self,
input: torch.Tensor,
hidden0_in: torch.Tensor | None = None,
hidden1_in: torch.Tensor | None = None,
):
squeeze0 = False
squeeze1 = False
if input.ndimension() == 1:
squeeze0 = True
input = input.unsqueeze(0).contiguous()

if input.ndimension() == 2:
squeeze1 = True
input = input.unsqueeze(1).contiguous()
batch, steps = input.shape[:2]

if hidden1_in is None and hidden0_in is None:
shape = (batch, steps) if not squeeze1 else (batch,)
hidden0_in, hidden1_in = [
torch.zeros(
*shape,
self.lstm.num_layers,
self.lstm.hidden_size,
device=input.device,
dtype=input.dtype,
)
for _ in range(2)
]
elif hidden1_in is None or hidden0_in is None:
raise RuntimeError(
f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
)
elif squeeze0:
hidden0_in = hidden0_in.unsqueeze(0)
hidden1_in = hidden1_in.unsqueeze(0)

# we only need the first hidden state
if not squeeze1:
_hidden0_in = hidden0_in[:, 0]
_hidden1_in = hidden1_in[:, 0]
else:
_hidden0_in = hidden0_in
_hidden1_in = hidden1_in
hidden = (
_hidden0_in.transpose(-3, -2).contiguous(),
_hidden1_in.transpose(-3, -2).contiguous(),
)

y0, hidden = self.lstm(input, hidden)
# dim 0 in hidden is num_layers, but that will conflict with tensordict
hidden = tuple(_h.transpose(0, 1) for _h in hidden)
y = self.linear(y0)

out = [y, hidden0_in, hidden1_in, *hidden]
if squeeze1:
# squeezes time
out[0] = out[0].squeeze(1)
if not squeeze1:
# we pad the hidden states with zero to make tensordict happy
for i in range(3, 5):
out[i] = torch.stack(
[torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)]
+ [out[i]],
1,
)
if squeeze0:
out = [_out.squeeze(0) for _out in out]
return tuple(out)

def forward(
self,
input: torch.Tensor,
hidden0_in: torch.Tensor | None = None,
hidden1_in: torch.Tensor | None = None,
):
input = self.mlp(input)
return self._lstm(input, hidden0_in, hidden1_in)
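As a usage note (not part of this diff), the shapes produced by this helper for the docstring's example sizes; the comments restate what the zero-padding above guarantees:

# Illustrative shape check only -- not part of this diff.
import torch

batch, time_steps, in_features, hidden_size, out_features = 7, 6, 4, 5, 10
net = LSTMNet(
    out_features,
    {"input_size": hidden_size, "hidden_size": hidden_size},
    {"out_features": hidden_size},
)
y, h0_in, h1_in, h0_out, h1_out = net(torch.randn(batch, time_steps, in_features))
# y:              (batch, time_steps, out_features)
# h0_in, h1_in:   (batch, time_steps, num_layers, hidden_size), zero-initialised here
# h0_out, h1_out: (batch, time_steps, num_layers, hidden_size), zeros except at the last step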
2 changes: 0 additions & 2 deletions test/mocking_classes.py
@@ -112,7 +112,6 @@ def __init__(
):
super().__init__(
device=kwargs.pop("device", "cpu"),
dtype=torch.get_default_dtype(),
allow_done_after_reset=kwargs.pop("allow_done_after_reset", False),
)
self.set_seed(seed)
@@ -926,7 +925,6 @@ def __init__(
super().__init__(
world_model,
device=device,
dtype=dtype,
batch_size=batch_size,
)
self.observation_spec = CompositeSpec(
16 changes: 8 additions & 8 deletions test/test_actors.py
@@ -63,8 +63,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
out_keys=[("data", "action")],
distribution_class=TanhDelta,
distribution_kwargs={
"min": action_spec.space.minimum,
"max": action_spec.space.maximum,
"min": action_spec.space.low,
"max": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
@@ -86,8 +86,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
out_keys=[("data", "action")],
distribution_class=TanhDelta,
distribution_kwargs={
"min": action_spec.space.minimum,
"max": action_spec.space.maximum,
"min": action_spec.space.low,
"max": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
@@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
out_keys=[("data", "action")],
distribution_class=TanhNormal,
distribution_kwargs={
"min": action_spec.space.minimum,
"max": action_spec.space.maximum,
"min": action_spec.space.low,
"max": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
@@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
out_keys=[("data", "action")],
distribution_class=TanhNormal,
distribution_kwargs={
"min": action_spec.space.minimum,
"max": action_spec.space.maximum,
"min": action_spec.space.low,
"max": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
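These test changes reflect the 0.4 rename of the bounded-space attributes from `minimum`/`maximum` to `low`/`high`. A small sketch of an updated call site, not part of this diff; the spec construction itself is illustrative:

# Illustrative sketch only -- not part of this diff.
import torch
from torchrl.data import BoundedTensorSpec

action_spec = BoundedTensorSpec(-1.0, 1.0, shape=(3,), dtype=torch.float32)
distribution_kwargs = {
    "min": action_spec.space.low,   # previously: action_spec.space.minimum
    "max": action_spec.space.high,  # previously: action_spec.space.maximum
}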
3 changes: 2 additions & 1 deletion test/test_collector.py
@@ -20,6 +20,7 @@
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
PONG_VERSIONED,
retry,
@@ -74,7 +75,7 @@
PARTIAL_MISSING_ERR,
RandomPolicy,
)
from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule

# torch.set_default_dtype(torch.double)
IS_WINDOWS = sys.platform == "win32"