From 36e13099885ff7fcf25a8748a58a4e5b87995ba2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Apr 2024 11:20:14 +0100 Subject: [PATCH] [Versioning] Deprecations for 0.4 (#2109) --- docs/source/reference/modules.rst | 1 - .../distributed_replay_buffer.py | 6 +- examples/memmap/memmap_speed_distributed.py | 4 +- test/_utils_internal.py | 156 +++++++++++ test/mocking_classes.py | 2 - test/test_actors.py | 16 +- test/test_collector.py | 3 +- test/test_helpers.py | 123 +-------- test/test_modules.py | 135 --------- test/test_rb.py | 24 +- test/test_shared.py | 6 +- test/test_transforms.py | 45 ++- torchrl/data/datasets/vd4rl.py | 2 +- torchrl/data/replay_buffers/storages.py | 58 +--- torchrl/data/tensor_specs.py | 20 -- torchrl/envs/common.py | 7 +- torchrl/envs/gym_like.py | 10 - torchrl/envs/model_based/common.py | 5 +- torchrl/envs/model_based/dreamer.py | 6 +- torchrl/envs/transforms/transforms.py | 12 +- torchrl/modules/__init__.py | 1 - torchrl/modules/models/__init__.py | 1 - torchrl/modules/models/models.py | 157 ----------- torchrl/modules/tensordict_module/actors.py | 20 +- torchrl/modules/tensordict_module/common.py | 16 +- torchrl/objectives/a2c.py | 46 ---- torchrl/objectives/common.py | 6 - torchrl/objectives/dqn.py | 24 +- torchrl/objectives/multiagent/qmixer.py | 12 +- torchrl/objectives/ppo.py | 46 ---- torchrl/objectives/reinforce.py | 46 ---- torchrl/objectives/value/functional.py | 6 +- torchrl/trainers/helpers/__init__.py | 4 +- torchrl/trainers/helpers/losses.py | 43 --- torchrl/trainers/helpers/models.py | 260 +----------------- torchrl/trainers/trainers.py | 12 +- 36 files changed, 243 insertions(+), 1098 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index c42376e4948..c12bba985d6 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -315,7 +315,6 @@ Regular modules MLP ConvNet Conv3dNet - LSTMNet SqueezeLayer Squeeze2dLayer diff --git 
a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index 0cb9aaaffbd..c7504fbf8ee 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -149,8 +149,10 @@ def _create_and_launch_data_collectors(self) -> None: class ReplayBufferNode(RemoteTensorDictReplayBuffer): - """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` means all of it's public methods are remotely invokable using `torch.rpc`. - Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures. + """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` + means all of its public methods are remotely invokable using `torch.rpc`. + Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation + cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures. Args: capacity (int): the maximum number of elements that can be stored in the replay buffer. 
diff --git a/examples/memmap/memmap_speed_distributed.py b/examples/memmap/memmap_speed_distributed.py index 61c100e0e4a..ec324e7cc55 100644 --- a/examples/memmap/memmap_speed_distributed.py +++ b/examples/memmap/memmap_speed_distributed.py @@ -9,7 +9,7 @@ import configargparse import torch import torch.distributed.rpc as rpc -from tensordict import MemmapTensor +from tensordict import MemoryMappedTensor parser = configargparse.ArgumentParser() parser.add_argument("--rank", default=-1, type=int) @@ -59,7 +59,7 @@ def op_on_tensor(idx): # create tensor tensor = torch.zeros(10000, 10000) if tensortype == "memmap": - tensor = MemmapTensor(tensor) + tensor = MemoryMappedTensor.from_tensor(tensor) elif tensortype == "tensor": pass else: diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 6c267768044..e43c0ff2ecf 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import contextlib import os @@ -18,6 +19,7 @@ import torch.cuda from tensordict import tensorclass, TensorDict +from torch import nn from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator from torchrl.data.utils import CloudpickleWrapper @@ -35,6 +37,8 @@ # Specified for test_utils.py __version__ = "0.3" +from torchrl.modules import MLP + def CARTPOLE_VERSIONED(): # load gym @@ -498,3 +502,155 @@ def new_func(*args, **kwargs): return func(*args, **kwargs) return CloudpickleWrapper(new_func) + + +class LSTMNet(nn.Module): + """An embedder for an LSTM preceded by an MLP. + + The forward method returns the hidden states of the current state + (input hidden states) and the output, as + the environment returns the 'observation' and 'next_observation'. 
+ + Because the LSTM kernel only returns the last hidden state, hidden states + are padded with zeros such that they have the right size to be stored in a + TensorDict of size [batch x time_steps]. + + If a 2D tensor is provided as input, it is assumed that it is a batch of data + with only one time step. This means that we explicitly assume that users will + unsqueeze inputs of a single batch with multiple time steps. + + Args: + out_features (int): number of output features. + lstm_kwargs (dict): the keyword arguments for the + :class:`~torch.nn.LSTM` layer. + mlp_kwargs (dict): the keyword arguments for the + :class:`~torchrl.modules.MLP` layer. + device (torch.device, optional): the device where the module should + be instantiated. + + Keyword Args: + lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that + indicates where the LSTM class is to be retrieved. The ``"torchrl"`` + backend (:class:`~torchrl.modules.LSTM`) is slower but works with + :func:`~torch.vmap` and should work with :func:`~torch.compile`. + Defaults to ``"torch"``. + + Examples: + >>> batch = 7 + >>> time_steps = 6 + >>> in_features = 4 + >>> out_features = 10 + >>> hidden_size = 5 + >>> net = LSTMNet( + ... out_features, + ... {"input_size": hidden_size, "hidden_size": hidden_size}, + ... {"out_features": hidden_size}, + ... 
) + >>> # test single step vs multi-step + >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + >>> x = torch.randn(batch, in_features) # 2 dims = single step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + + """ + + def __init__( + self, + out_features: int, + lstm_kwargs, + mlp_kwargs, + device=None, + *, + lstm_backend: str | None = None, + ) -> None: + super().__init__() + lstm_kwargs.update({"batch_first": True}) + self.mlp = MLP(device=device, **mlp_kwargs) + if lstm_backend is None: + lstm_backend = "torch" + self.lstm_backend = lstm_backend + if self.lstm_backend == "torch": + LSTM = nn.LSTM + else: + from torchrl.modules.tensordict_module.rnn import LSTM + self.lstm = LSTM(device=device, **lstm_kwargs) + self.linear = nn.LazyLinear(out_features, device=device) + + def _lstm( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + squeeze0 = False + squeeze1 = False + if input.ndimension() == 1: + squeeze0 = True + input = input.unsqueeze(0).contiguous() + + if input.ndimension() == 2: + squeeze1 = True + input = input.unsqueeze(1).contiguous() + batch, steps = input.shape[:2] + + if hidden1_in is None and hidden0_in is None: + shape = (batch, steps) if not squeeze1 else (batch,) + hidden0_in, hidden1_in = [ + torch.zeros( + *shape, + self.lstm.num_layers, + self.lstm.hidden_size, + device=input.device, + dtype=input.dtype, + ) + for _ in range(2) + ] + elif hidden1_in is None or hidden0_in is None: + raise RuntimeError( + f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + ) + elif squeeze0: + hidden0_in = hidden0_in.unsqueeze(0) + hidden1_in = hidden1_in.unsqueeze(0) + + # we only need the first hidden state + if not squeeze1: + _hidden0_in = hidden0_in[:, 0] + _hidden1_in = hidden1_in[:, 0] + else: + _hidden0_in = hidden0_in + _hidden1_in = hidden1_in + 
hidden = ( + _hidden0_in.transpose(-3, -2).contiguous(), + _hidden1_in.transpose(-3, -2).contiguous(), + ) + + y0, hidden = self.lstm(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = tuple(_h.transpose(0, 1) for _h in hidden) + y = self.linear(y0) + + out = [y, hidden0_in, hidden1_in, *hidden] + if squeeze1: + # squeezes time + out[0] = out[0].squeeze(1) + if not squeeze1: + # we pad the hidden states with zero to make tensordict happy + for i in range(3, 5): + out[i] = torch.stack( + [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] + + [out[i]], + 1, + ) + if squeeze0: + out = [_out.squeeze(0) for _out in out] + return tuple(out) + + def forward( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + input = self.mlp(input) + return self._lstm(input, hidden0_in, hidden1_in) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ec9cec7fabd..75769215ce5 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -112,7 +112,6 @@ def __init__( ): super().__init__( device=kwargs.pop("device", "cpu"), - dtype=torch.get_default_dtype(), allow_done_after_reset=kwargs.pop("allow_done_after_reset", False), ) self.set_seed(seed) @@ -926,7 +925,6 @@ def __init__( super().__init__( world_model, device=device, - dtype=dtype, batch_size=batch_size, ) self.observation_spec = CompositeSpec( diff --git a/test/test_actors.py b/test/test_actors.py index ddefcea274c..560566286ae 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -63,8 +63,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -86,8 +86,8 @@ def 
test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions out_keys=[("data", "action")], distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions out_keys=[("data", "action")], distribution_class=TanhNormal, distribution_kwargs={ - "min": action_spec.space.minimum, - "max": action_spec.space.maximum, + "min": action_spec.space.low, + "max": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, diff --git a/test/test_collector.py b/test/test_collector.py index 230ff159c28..eee42a11fc2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -20,6 +20,7 @@ generate_seeds, get_available_devices, get_default_devices, + LSTMNet, PENDULUM_VERSIONED, PONG_VERSIONED, retry, @@ -74,7 +75,7 @@ PARTIAL_MISSING_ERR, RandomPolicy, ) -from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule +from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" diff --git a/test/test_helpers.py b/test/test_helpers.py index 46036346de5..f468eddf6ed 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -7,11 +7,11 @@ import dataclasses import pathlib import sys - from time import sleep import pytest import torch + from _utils_internal import generate_seeds, 
get_default_devices from torchrl._utils import timeit @@ -38,8 +38,6 @@ FlattenObservation, TransformedEnv, ) -from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules.tensordict_module.common import _has_functorch from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.helpers.envs import ( EnvConfig, @@ -50,8 +48,6 @@ DiscreteModelConfig, DreamerConfig, make_dqn_actor, - make_redq_model, - REDQModelConfig, ) TORCH_VERSION = version.parse(torch.__version__) @@ -162,123 +158,6 @@ def test_dqn_maker( proof_environment.close() -@pytest.mark.skipif(not _has_functorch, reason="functorch not installed") -@pytest.mark.skipif(not _has_hydra, reason="No hydra library found") -@pytest.mark.skipif(not _has_gym, reason="No gym library found") -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) -@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)]) -@pytest.mark.parametrize("exploration", [ExplorationType.MODE, ExplorationType.RANDOM]) -def test_redq_make(device, from_pixels, gsde, exploration): - if not gsde and exploration != ExplorationType.RANDOM: - pytest.skip("no need to test this setting") - flags = list(from_pixels + gsde) - if gsde and from_pixels: - pytest.skip("gsde and from_pixels are incompatible") - - config_fields = [ - (config_field.name, config_field.type, config_field) - for config_cls in ( - EnvConfig, - REDQModelConfig, - ) - for config_field in dataclasses.fields(config_cls) - ] - - Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) - cs = ConfigStore.instance() - cs.store(name="config", node=Config) - with initialize(version_base="1.1", config_path=None): - cfg = compose(config_name="config", overrides=flags) - - env_maker = ( - ContinuousActionConvMockEnvNumpy - if from_pixels - else ContinuousActionVecMockEnv - ) - env_maker = transformed_env_constructor( - cfg, - 
use_env_creator=False, - custom_env_maker=env_maker, - stats={"loc": 0.0, "scale": 1.0}, - ) - proof_environment = env_maker() - - model = make_redq_model( - proof_environment, - device=device, - cfg=cfg, - ) - actor, qvalue = model - td = proof_environment.reset().to(device) - with set_exploration_type(exploration): - actor(td) - expected_keys = [ - "done", - "terminated", - "action", - "sample_log_prob", - "loc", - "scale", - "step_count", - "is_init", - ] - if len(gsde): - expected_keys += ["_eps_gSDE"] - if from_pixels: - expected_keys += [ - "hidden", - "pixels", - "pixels_orig", - ] - else: - expected_keys += ["observation_vector", "observation_orig"] - - try: - assert set(td.keys()) == set(expected_keys) - except AssertionError: - proof_environment.close() - raise - - if cfg.gSDE: - tsf_loc = actor.module[0].module[-1].module.transform(td.get("loc")) - if exploration == ExplorationType.RANDOM: - with pytest.raises(AssertionError): - torch.testing.assert_close(td.get("action"), tsf_loc) - else: - torch.testing.assert_close(td.get("action"), tsf_loc) - - qvalue(td) - expected_keys = [ - "done", - "terminated", - "action", - "sample_log_prob", - "state_action_value", - "loc", - "scale", - "step_count", - "is_init", - ] - if len(gsde): - expected_keys += ["_eps_gSDE"] - if from_pixels: - expected_keys += [ - "hidden", - "pixels", - "pixels_orig", - ] - else: - expected_keys += ["observation_vector", "observation_orig"] - try: - assert set(td.keys()) == set(expected_keys) - except AssertionError: - proof_environment.close() - raise - proof_environment.close() - del proof_environment - - @pytest.mark.parametrize("initial_seed", range(5)) def test_seed_generator(initial_seed): num_seeds = 100 diff --git a/test/test_modules.py b/test/test_modules.py index c9984e178c5..65c2d2613de 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -23,7 +23,6 @@ GRUCell, LSTM, LSTMCell, - LSTMNet, MultiAgentConvNet, MultiAgentMLP, OnlineDTActor, @@ -350,140 +349,6 @@ def 
test_noisy(layer_class, device, seed=0): torch.testing.assert_close(y1, y2) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("out_features", [3, 4]) -@pytest.mark.parametrize("hidden_size", [8, 9]) -@pytest.mark.parametrize("num_layers", [1, 2]) -@pytest.mark.parametrize("has_precond_hidden", [True, False]) -def test_lstm_net( - device, - out_features, - hidden_size, - num_layers, - has_precond_hidden, - double_prec_fixture, -): - torch.manual_seed(0) - batch = 5 - time_steps = 6 - in_features = 7 - net = LSTMNet( - out_features, - { - "input_size": hidden_size, - "hidden_size": hidden_size, - "num_layers": num_layers, - }, - {"out_features": hidden_size}, - device=device, - ) - # test single step vs multi-step - x = torch.randn(batch, time_steps, in_features, device=device) - x_unbind = x.unbind(1) - tds_loop = [] - if has_precond_hidden: - hidden0_out0, hidden1_out0 = torch.randn( - 2, batch, time_steps, num_layers, hidden_size, device=device - ) - hidden0_out0[:, 1:] = 0.0 - hidden1_out0[:, 1:] = 0.0 - hidden0_out = hidden0_out0[:, 0] - hidden1_out = hidden1_out0[:, 0] - else: - hidden0_out, hidden1_out = None, None - hidden0_out0, hidden1_out0 = None, None - - for _x in x_unbind: - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - _x, hidden0_out, hidden1_out - ) - td = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [batch], - ) - tds_loop.append(td) - tds_loop = torch.stack(tds_loop, 1) - - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - x, hidden0_out0, hidden1_out0 - ) - tds_vec = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [batch, time_steps], - ) - torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) - torch.testing.assert_close( - tds_vec["hidden0_out"][:, -1], tds_loop["hidden0_out"][:, 
-1] - ) - torch.testing.assert_close( - tds_vec["hidden1_out"][:, -1], tds_loop["hidden1_out"][:, -1] - ) - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("out_features", [3, 5]) -@pytest.mark.parametrize("hidden_size", [3, 5]) -def test_lstm_net_nobatch(device, out_features, hidden_size): - time_steps = 6 - in_features = 4 - net = LSTMNet( - out_features, - {"input_size": hidden_size, "hidden_size": hidden_size}, - {"out_features": hidden_size}, - device=device, - ) - # test single step vs multi-step - x = torch.randn(time_steps, in_features, device=device) - x_unbind = x.unbind(0) - tds_loop = [] - hidden0_in, hidden1_in, hidden0_out, hidden1_out = [ - None, - ] * 4 - for _x in x_unbind: - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net( - _x, hidden0_out, hidden1_out - ) - td = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [], - ) - tds_loop.append(td) - tds_loop = torch.stack(tds_loop, 0) - - y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x.unsqueeze(0)) - tds_vec = TensorDict( - { - "y": y, - "hidden0_in": hidden0_in, - "hidden1_in": hidden1_in, - "hidden0_out": hidden0_out, - "hidden1_out": hidden1_out, - }, - [1, time_steps], - ).squeeze(0) - torch.testing.assert_close(tds_vec["y"], tds_loop["y"]) - torch.testing.assert_close(tds_vec["hidden0_out"][-1], tds_loop["hidden0_out"][-1]) - torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) - - @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [3, 5]) class TestPlanner: diff --git a/test/test_rb.py b/test/test_rb.py index 9582738617c..36e950f5ed8 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -613,19 +613,21 @@ class TC: ) else: raise NotImplementedError + + if ( + storage_type is LazyMemmapStorage + and device_storage != "auto" + and device_storage.type != "cpu" + ): 
+ with pytest.raises(ValueError, match="Memory map device other than CPU"): + storage_type(max_size=10, device=device_storage) + return storage = storage_type(max_size=10, device=device_storage) - if device_storage == "auto": - device_storage = device_data - if storage_type is LazyMemmapStorage and device_storage.type == "cuda": - with pytest.warns( - DeprecationWarning, match="Support for Memmap device other than CPU" - ): - # this is rather brittle and will fail with some indices - # when both device (storage and data) don't match (eg, range()) - storage.set(0, data) + storage.set(0, data) + if device_storage != "auto": + assert storage.get(0).device.type == device_storage.type else: - storage.set(0, data) - assert storage.get(0).device.type == device_storage.type + assert storage.get(0).device.type == storage.device.type @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) diff --git a/test/test_shared.py b/test/test_shared.py index 912f230e8cf..bc2638269e6 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -64,7 +64,7 @@ def test_shared(self, indexing_method): batch_size=[], ).share_memory_() elif indexing_method == 1: - subtd = td.get_sub_tensordict(0) + subtd = td._get_sub_tensordict(0) elif indexing_method == 2: subtd = td[0] else: @@ -182,14 +182,14 @@ def test_memmap(idx, dtype, large_scale=False): torchrl_logger.info("\nTesting writing to TD") for i in range(2): t0 = time.time() - sub_td_sm = td_sm.get_sub_tensordict(idx) + sub_td_sm = td_sm._get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: torchrl_logger.info(f"sm td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a")) t0 = time.time() - sub_td_sm = td_memmap.get_sub_tensordict(idx) + sub_td_sm = td_memmap._get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: torchrl_logger.info(f"memmap td: {time.time() - t0:4.4f} sec") diff --git 
a/test/test_transforms.py b/test/test_transforms.py index c9d2fb8c031..4396ea79c41 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -785,7 +785,7 @@ def test_transform_env_clone(self): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + @pytest.mark.parametrize("padding", ["constant", "same"]) def test_transform_model(self, dim, N, padding): # test equivalence between transforms within an env and within a rb key1 = "observation" @@ -838,7 +838,7 @@ def test_transform_model(self, dim, N, padding): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["same", "constant"]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, dim, N, padding, rbclass): # test equivalence between transforms within an env and within a rb @@ -870,7 +870,7 @@ def test_transform_rb(self, dim, N, padding, rbclass): @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) - @pytest.mark.parametrize("padding", ["same", "zeros", "constant"]) + @pytest.mark.parametrize("padding", ["same", "constant"]) def test_transform_as_inverse(self, dim, N, padding): # test equivalence between transforms within an env and within a rb in_keys = ["observation", ("next", "observation")] @@ -987,7 +987,7 @@ def test_transform_no_env(self, device, d, batch_size, dim, N): assert v1 is not v2 @pytest.mark.skipif(not _has_gym, reason="gym required for this test") - @pytest.mark.parametrize("padding", ["zeros", "constant", "same"]) + @pytest.mark.parametrize("padding", ["constant", "same"]) @pytest.mark.parametrize("envtype", ["gym", "conv"]) def test_tranform_offline_against_online(self, padding, envtype): torch.manual_seed(0) @@ -1027,10 +1027,7 @@ def test_tranform_offline_against_online(self, padding, 
envtype): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch_size", [(), (1,), (1, 2)]) @pytest.mark.parametrize("d", range(2, 3)) - @pytest.mark.parametrize( - "dim", - [-3], - ) + @pytest.mark.parametrize("dim", [-3]) @pytest.mark.parametrize("N", [2, 4]) def test_transform_compose(self, device, d, batch_size, dim, N): key1 = "first key" @@ -4177,11 +4174,11 @@ def test_observationnorm( ) observation_spec = on.transform_observation_spec(observation_spec) if standard_normal: - assert (observation_spec.space.minimum == -loc / scale).all() - assert (observation_spec.space.maximum == (1 - loc) / scale).all() + assert (observation_spec.space.low == -loc / scale).all() + assert (observation_spec.space.high == (1 - loc) / scale).all() else: - assert (observation_spec.space.minimum == loc).all() - assert (observation_spec.space.maximum == scale + loc).all() + assert (observation_spec.space.low == loc).all() + assert (observation_spec.space.high == scale + loc).all() else: observation_spec = CompositeSpec( @@ -5097,9 +5094,9 @@ def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10): f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}" ), ): - t.reset(td) + t._reset(td, td.empty()) else: - t.reset(td) + t._reset(td, td.empty()) class TestReward2Go(TransformBase): @@ -6149,8 +6146,8 @@ def test_transform_no_env(self, keys, batch, device): observation_spec ) assert observation_spec.shape == torch.Size([3, 16, 16]) - assert (observation_spec.space.minimum == 0).all() - assert (observation_spec.space.maximum == 1).all() + assert (observation_spec.space.low == 0).all() + assert (observation_spec.space.high == 1).all() else: observation_spec = CompositeSpec( { @@ -6198,8 +6195,8 @@ def test_transform_compose(self, keys, batch, device): observation_spec ) assert observation_spec.shape == torch.Size([3, 16, 16]) - assert (observation_spec.space.minimum == 0).all() - assert 
(observation_spec.space.maximum == 1).all() + assert (observation_spec.space.low == 0).all() + assert (observation_spec.space.high == 1).all() else: observation_spec = CompositeSpec( { @@ -8039,14 +8036,14 @@ def test_independent_reward_specs_from_shared_env(self): t1_reward_spec = t1.reward_spec t2_reward_spec = t2.reward_spec - assert t1_reward_spec.space.minimum == 0 - assert t1_reward_spec.space.maximum == 4 + assert t1_reward_spec.space.low == 0 + assert t1_reward_spec.space.high == 4 - assert t2_reward_spec.space.minimum == -2 - assert t2_reward_spec.space.maximum == 2 + assert t2_reward_spec.space.low == -2 + assert t2_reward_spec.space.high == 2 - assert base_env.reward_spec.space.minimum == -np.inf - assert base_env.reward_spec.space.maximum == np.inf + assert base_env.reward_spec.space.low == -np.inf + assert base_env.reward_spec.space.high == np.inf def test_allow_done_after_reset(self): base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index 961702aca0c..7290a714155 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -368,7 +368,7 @@ def _process_data(cls, td: TensorDict): td.set("truncated", torch.zeros_like(td.get(("next", "truncated")))) pixels = td.get("pixels") - subtd = td.get_sub_tensordict(slice(0, -1)) + subtd = td._get_sub_tensordict(slice(0, -1)) subtd.set(("next", "pixels"), pixels[1:], inplace=True) state = td.get("state", None) if state is not None: diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index a1ada2eb72e..6058de290e3 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -24,22 +24,14 @@ TensorDict, TensorDictBase, ) -from tensordict.memmap import MemmapTensor, MemoryMappedTensor +from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _STRDTYPE2DTYPE from torch import multiprocessing as mp - 
from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten -from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger +from torchrl._utils import implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES -try: - from torchsnapshot.serialization import tensor_from_memoryview - - _has_ts = True -except ImportError: - _has_ts = False - SINGLE_TENSOR_BUFFER_NAME = os.environ.get( "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" ) @@ -467,8 +459,8 @@ def loads(self, path): _storage = TensorDict.load_memmap(path) if not self.initialized: # this should not be reached if is_pytree=True - self._storage = _storage - self.initialized = True + self._init(_storage[0]) + self._storage.update_(_storage) else: self._storage.copy_(_storage) self._len = _len @@ -1026,7 +1018,12 @@ def __init__( self.scratch_dir = str(scratch_dir) if self.scratch_dir[-1] != "/": self.scratch_dir += "/" - self.device = torch.device(device) if device != "auto" else device + self.device = torch.device(device) if device != "auto" else torch.device("cpu") + if self.device.type != "cpu": + raise ValueError( + "Memory map device other than CPU isn't supported. To cast your data to the desired device, " + "use `buffer.append_transform(lambda x: x.to(device)` or a similar transform." + ) self._len = 0 def state_dict(self) -> Dict[str, Any]: @@ -1093,11 +1090,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if self.device == "auto": self.device = data.device if self.device.type != "cpu": - warnings.warn( - "Support for Memmap device other than CPU will be deprecated in v0.4.0. 
" - "Using a 'cuda' device may be suboptimal.", - category=DeprecationWarning, - ) + raise RuntimeError("Support for Memmap device other than CPU is deprecated") def max_size_along_dim0(data_shape): if self.ndim > 1: @@ -1128,17 +1121,7 @@ def max_size_along_dim0(data_shape): def get(self, index: Union[int, Sequence[int], slice]) -> Any: result = super().get(index) - - # to be deprecated in v0.4 - def map_device(tensor): - if tensor.device != self.device: - return tensor.to(self.device, non_blocking=False) - return tensor - - if is_tensor_collection(result): - return map_device(result) - else: - return tree_map(map_device, result) + return result class StorageEnsemble(Storage): @@ -1301,25 +1284,10 @@ def __repr__(self): # Utils -def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: - if _CKPT_BACKEND == "torchsnapshot" and not _has_ts: - raise ImportError( - "the checkpointing backend is set to torchsnapshot but the library is not installed. Consider installing the library or switch to another backend. " - f"Supported backends are {_CKPT_BACKEND.backends}" - ) +def _mem_map_tensor_as_tensor(mem_map_tensor) -> torch.Tensor: if isinstance(mem_map_tensor, torch.Tensor): # This will account for MemoryMappedTensors return mem_map_tensor - if _CKPT_BACKEND == "torchsnapshot": - # TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor - # as a Tensor for saving and loading purposes. This doesn't incur any copy. 
- return tensor_from_memoryview( - dtype=mem_map_tensor.dtype, - shape=list(mem_map_tensor.shape), - mv=memoryview(mem_map_tensor._memmap_array), - ) - elif _CKPT_BACKEND == "torch": - return mem_map_tensor._tensor def _collate_list_tensordict(x): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 105c13214e0..3228afdfd8b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -382,22 +382,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @property - def minimum(self): - warnings.warn( - f"{type(self)}.minimum is going to be deprecated in favour of {type(self)}.low in v0.4.0", - category=DeprecationWarning, - ) - return self._low.to(self.device) - - @property - def maximum(self): - warnings.warn( - f"{type(self)}.maximum is going to be deprecated in favour of {type(self)}.high in v0.4.0", - category=DeprecationWarning, - ) - return self._high.to(self.device) - @low.setter def low(self, value): self.device = value.device @@ -1596,10 +1580,6 @@ class BoundedTensorSpec(TensorSpec): """ # SPEC_HANDLED_FUNCTIONS = {} - DEPRECATED_KWARGS = ( - "The `minimum` and `maximum` keyword arguments are now " - "deprecated in favour of `low` and `high` in v0.4.0." - ) CONFLICTING_KWARGS = ( "The keyword arguments {} and {} conflict. Only one of these can be passed." ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8712c74340a..0715cabf29f 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -9,7 +9,7 @@ import functools import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import numpy as np import torch @@ -202,7 +202,6 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): on that device and it is expected that all inputs and outputs will live on that device. Defaults to ``None``. 
- dtype (deprecated): dtype of the observations. Will be deprecated in v0.4. batch_size (torch.Size or equivalent, optional): batch-size of the environment. Corresponds to the leading dimension of all the input and output tensordicts the environment reads and writes. Defaults to an empty batch-size. @@ -341,7 +340,6 @@ def __init__( self, *, device: DEVICE_TYPING = None, - dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, allow_done_after_reset: bool = False, @@ -365,7 +363,6 @@ def __init__( ) super().__init__() - self.dtype = dtype_map.get(dtype, dtype) if "is_closed" not in self.__dir__(): self.is_closed = True if batch_size is not None: @@ -2983,7 +2980,6 @@ class _EnvWrapper(EnvBase): def __init__( self, *args, - dtype: Optional[np.dtype] = None, device: DEVICE_TYPING = NO_DEFAULT, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, @@ -3001,7 +2997,6 @@ def __init__( device = torch.device("cpu") super().__init__( device=device, - dtype=dtype, batch_size=batch_size, allow_done_after_reset=allow_done_after_reset, ) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index b27c1f795a2..9cbec79211d 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -529,13 +529,3 @@ def __repr__(self) -> str: @property def info_dict_reader(self): return self._info_dict_reader - - @info_dict_reader.setter - def info_dict_reader(self, value: callable): - warnings.warn( - f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " - f"This method will append a reader to the list of existing readers (if any). 
" - f"Setting info_dict_reader directly will be deprecated in v0.4.0.", - category=DeprecationWarning, - ) - self._info_dict_reader.append(value) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index c1940f75a8f..f6b3f97cd4a 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -5,9 +5,8 @@ import abc import warnings -from typing import List, Optional, Union +from typing import List, Optional -import numpy as np import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -117,13 +116,11 @@ def __init__( params: Optional[List[torch.Tensor]] = None, buffers: Optional[List[torch.Tensor]] = None, device: DEVICE_TYPING = "cpu", - dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, ): super(ModelBasedEnvBase, self).__init__( device=device, - dtype=dtype, batch_size=batch_size, run_type_checks=run_type_checks, ) diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index f44c4aa025c..5609861c75f 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -3,9 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Optional, Tuple, Union +from typing import Optional, Tuple -import numpy as np import torch from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -27,11 +26,10 @@ def __init__( belief_shape: Tuple[int, ...], obs_decoder: TensorDictModule = None, device: DEVICE_TYPING = "cpu", - dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, ): super(DreamerEnv, self).__init__( - world_model, device=device, dtype=dtype, batch_size=batch_size + world_model, device=device, batch_size=batch_size ) self.obs_decoder = obs_decoder self.prior_shape = prior_shape diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ccab829d480..acec170c7d0 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -236,10 +236,6 @@ def out_keys_inv(self, value): value = [unravel_key(val) for val in value] self._out_keys_inv = value - def reset(self, tensordict): - warnings.warn("Transform.reset public method will be derpecated in v0.4.0.") - return self._reset(tensordict, tensordict_reset=tensordict) - def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: @@ -2836,13 +2832,7 @@ def __init__( if padding not in self.ACCEPTED_PADDING: raise ValueError(f"padding must be one of {self.ACCEPTED_PADDING}") if padding == "zeros": - warnings.warn( - "Padding option 'zeros' will be deprecated in v0.4.0. 
" - "Please use 'constant' padding with padding_value 0 instead.", - category=DeprecationWarning, - ) - padding = "constant" - padding_value = 0 + raise RuntimeError("Padding option 'zeros' will is deprecated") self.padding = padding self.padding_value = padding_value for in_key in self.in_keys: diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index a987e701672..4a3c5e716e8 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -29,7 +29,6 @@ DreamerActor, DTActor, DuelingCnnDQNet, - LSTMNet, MLP, MultiAgentConvNet, MultiAgentMLP, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 7b11cae9515..fb0cc0135b8 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -27,7 +27,6 @@ DTActor, DuelingCnnDQNet, DuelingMlpDQNet, - LSTMNet, MLP, OnlineDTActor, ) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 8e6fc75e12e..23c229c6524 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -6,7 +6,6 @@ import dataclasses -import warnings from copy import deepcopy from numbers import Number from typing import Callable, Dict, List, Sequence, Tuple, Type, Union @@ -1481,162 +1480,6 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens return value -class LSTMNet(nn.Module): - """An embedder for an LSTM preceded by an MLP. - - The forward method returns the hidden states of the current state - (input hidden states) and the output, as - the environment returns the 'observation' and 'next_observation'. - - Because the LSTM kernel only returns the last hidden state, hidden states - are padded with zeros such that they have the right size to be stored in a - TensorDict of size [batch x time_steps]. - - If a 2D tensor is provided as input, it is assumed that it is a batch of data - with only one time step. 
This means that we explicitely assume that users will - unsqueeze inputs of a single batch with multiple time steps. - - Args: - out_features (int): number of output features. - lstm_kwargs (dict): the keyword arguments for the - :class:`~torch.nn.LSTM` layer. - mlp_kwargs (dict): the keyword arguments for the - :class:`~torchrl.modules.MLP` layer. - device (torch.device, optional): the device where the module should - be instantiated. - - Keyword Args: - lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that - indeicates where the LSTM class is to be retrieved. The ``"torchrl"`` - backend (:class:`~torchrl.modules.LSTM`) is slower but works with - :func:`~torch.vmap` and should work with :func:`~torch.compile`. - Defaults to ``"torch"``. - - Examples: - >>> batch = 7 - >>> time_steps = 6 - >>> in_features = 4 - >>> out_features = 10 - >>> hidden_size = 5 - >>> net = LSTMNet( - ... out_features, - ... {"input_size": hidden_size, "hidden_size": hidden_size}, - ... {"out_features": hidden_size}, - ... 
) - >>> # test single step vs multi-step - >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - >>> x = torch.randn(batch, in_features) # 2 dims = single step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - - """ - - def __init__( - self, - out_features: int, - lstm_kwargs: Dict, - mlp_kwargs: Dict, - device: DEVICE_TYPING | None = None, - *, - lstm_backend: str | None = None, - ) -> None: - warnings.warn( - "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", - category=DeprecationWarning, - ) - super().__init__() - lstm_kwargs.update({"batch_first": True}) - self.mlp = MLP(device=device, **mlp_kwargs) - if lstm_backend is None: - lstm_backend = "torch" - self.lstm_backend = lstm_backend - if self.lstm_backend == "torch": - LSTM = nn.LSTM - else: - from torchrl.modules.tensordict_module.rnn import LSTM - self.lstm = LSTM(device=device, **lstm_kwargs) - self.linear = nn.LazyLinear(out_features, device=device) - - def _lstm( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - squeeze0 = False - squeeze1 = False - if input.ndimension() == 1: - squeeze0 = True - input = input.unsqueeze(0).contiguous() - - if input.ndimension() == 2: - squeeze1 = True - input = input.unsqueeze(1).contiguous() - batch, steps = input.shape[:2] - - if hidden1_in is None and hidden0_in is None: - shape = (batch, steps) if not squeeze1 else (batch,) - hidden0_in, hidden1_in = [ - torch.zeros( - *shape, - self.lstm.num_layers, - self.lstm.hidden_size, - device=input.device, - dtype=input.dtype, - ) - for _ in range(2) - ] - elif hidden1_in is None or hidden0_in is None: - raise RuntimeError( - f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" - ) - elif 
squeeze0: - hidden0_in = hidden0_in.unsqueeze(0) - hidden1_in = hidden1_in.unsqueeze(0) - - # we only need the first hidden state - if not squeeze1: - _hidden0_in = hidden0_in[:, 0] - _hidden1_in = hidden1_in[:, 0] - else: - _hidden0_in = hidden0_in - _hidden1_in = hidden1_in - hidden = ( - _hidden0_in.transpose(-3, -2).contiguous(), - _hidden1_in.transpose(-3, -2).contiguous(), - ) - - y0, hidden = self.lstm(input, hidden) - # dim 0 in hidden is num_layers, but that will conflict with tensordict - hidden = tuple(_h.transpose(0, 1) for _h in hidden) - y = self.linear(y0) - - out = [y, hidden0_in, hidden1_in, *hidden] - if squeeze1: - # squeezes time - out[0] = out[0].squeeze(1) - if not squeeze1: - # we pad the hidden states with zero to make tensordict happy - for i in range(3, 5): - out[i] = torch.stack( - [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] - + [out[i]], - 1, - ) - if squeeze0: - out = [_out.squeeze(0) for _out in out] - return tuple(out) - - def forward( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.mlp(input) - return self._lstm(input, hidden0_in, hidden1_in) - - class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. 
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 490d1fcb5ad..8561b026f3c 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -443,7 +443,7 @@ class QValueModule(TensorDictModuleBase): def __init__( self, - action_space: Optional[str], + action_space: Optional[str] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, @@ -452,11 +452,7 @@ def __init__( safe: bool = False, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise TypeError("Using specs in action_space is deprecated") action_space, spec = _process_action_space_spec(action_space, spec) self.action_space = action_space self.var_nums = var_nums @@ -929,11 +925,7 @@ def __init__( out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise RuntimeError("Using specs in action_space is deprecated") action_space, _ = _process_action_space_spec(action_space, None) self.qvalue_model = DistributionalQValueModule( action_space=action_space, @@ -1196,11 +1188,7 @@ def __init__( make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, - ) + raise RuntimeError("Using specs in action_space is deprecated") action_space, spec = _process_action_space_spec(action_space, spec) self.action_space = action_space 
self.action_value_key = action_value_key diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 8dd621c98b2..7ac5d9873e5 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,7 +9,6 @@ import inspect import re import warnings -from numbers import Number from typing import Iterable, List, Optional, Type, Union import torch @@ -503,19 +502,8 @@ class DistributionalDQNnet(TensorDictModuleBase): "instead." ) - def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): + def __init__(self, *, in_keys=None, out_keys=None): super().__init__() - if DQNet is not None: - warnings.warn( - f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", - category=DeprecationWarning, - ) - if not ( - not isinstance(DQNet.out_features, Number) - and len(DQNet.out_features) > 1 - ): - raise RuntimeError(self._wrong_out_feature_dims_error) - self.dqn = DQNet if in_keys is None: in_keys = ["action_value"] if out_keys is None: @@ -527,8 +515,6 @@ def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): def forward(self, tensordict): for in_key, out_key in zip(self.in_keys, self.out_keys): q_values = tensordict.get(in_key) - if self.dqn is not None: - q_values = self.dqn(q_values) if q_values.ndimension() < 2: raise RuntimeError( self._wrong_out_feature_dims_error.format(q_values.shape) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6487ad0597a..dd5c162f8b0 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -328,51 +327,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use 
{self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - @property def in_keys(self): keys = [ diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 6b6fd391560..cfe8b793454 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -217,12 +217,6 @@ def convert_to_functional( will carry gradients as expected. """ - if kwargs.pop("funs_to_decorate", None) is not None: - warnings.warn( - "funs_to_decorate is without effect with the new objective API. 
This " - "warning will be replaced by an error in v0.4.0.", - category=DeprecationWarning, - ) if kwargs: raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}") # To make it robust to device casting, we must register list of diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 505f07f55c0..3e219c9b72e 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -173,23 +173,13 @@ def __init__( value_network: Union[QValueActor, nn.Module], *, loss_function: Optional[str] = "l2", - delay_value: bool = None, + delay_value: bool = True, double_dqn: bool = False, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, reduction: str = None, ) -> None: - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." - ) - delay_value = False if reduction is None: reduction = "mean" super().__init__() @@ -449,20 +439,10 @@ def __init__( value_network: Union[DistributionalQValueActor, nn.Module], *, gamma: float, - delay_value: bool = None, + delay_value: bool = True, priority_key: str = None, reduction: str = None, ): - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." 
- ) - delay_value = False if reduction is None: reduction = "mean" super().__init__() diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index fcfcba49ca1..f3994abd1b2 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -189,21 +189,11 @@ def __init__( mixer_network: Union[TensorDictModule, nn.Module], *, loss_function: Optional[str] = "l2", - delay_value: bool = None, + delay_value: bool = True, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, ) -> None: - if delay_value is None: - warnings.warn( - f"You did not provide a delay_value argument for {type(self)}. " - "Currently (v0.3) the default for delay_value is `False` but as of " - "v0.4 it will be `True`. Make sure to adapt your code depending " - "on your preferred configuration. " - "To remove this warning, indicate the value of delay_value in your " - "script." - ) - delay_value = False super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index a26b90462c6..7264d5d6cbe 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -7,7 +7,6 @@ import contextlib import math -import warnings from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -383,51 +382,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. 
This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - def _set_in_keys(self): keys = [ self.tensor_keys.action, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 96f15e8ab69..aa931b97c13 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import warnings from copy import deepcopy from dataclasses import dataclass @@ -317,51 +316,6 @@ def __init__( def functional(self): return self._functional - @property - def actor(self): - warnings.warn( - f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network - - @property - def critic(self): - warnings.warn( - f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. 
This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network - - @property - def actor_params(self): - warnings.warn( - f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.actor_network_params - - @property - def critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.critic_network_params - - @property - def target_critic_params(self): - warnings.warn( - f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " - "link will be removed in v0.4.", - category=DeprecationWarning, - ) - return self.target_critic_network_params - def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 082c0ae9e9a..d3ad8d93ca4 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,7 +12,6 @@ import torch -from tensordict import MemmapTensor __all__ = [ "generalized_advantage_estimate", @@ -59,10 +58,7 @@ def transposed_fun(*args, **kwargs): time_dim = kwargs.pop("time_dim", -2) def transpose_tensor(tensor): - if ( - not isinstance(tensor, (torch.Tensor, MemmapTensor)) - or tensor.numel() <= 1 - ): + if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1: return tensor, False if time_dim >= 0: timedim = time_dim - tensor.ndim diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index 2f7e65a4069..b09becdc15a 100644 --- a/torchrl/trainers/helpers/__init__.py +++ 
b/torchrl/trainers/helpers/__init__.py @@ -16,7 +16,7 @@ transformed_env_constructor, ) from .logger import LoggerConfig -from .losses import make_dqn_loss, make_redq_loss, make_target_updater -from .models import make_dqn_actor, make_dreamer, make_redq_model +from .losses import make_dqn_loss, make_target_updater +from .models import make_dqn_actor, make_dreamer from .replay_buffer import make_replay_buffer from .trainers import make_trainer diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index a949bea6718..152d7e2891f 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -3,14 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple -from torchrl.modules import ActorCriticOperator, ActorValueOperator from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate from torchrl.objectives.common import LossModule -from torchrl.objectives.deprecated import REDQLoss_deprecated from torchrl.objectives.utils import TargetNetUpdater @@ -38,46 +35,6 @@ def make_target_updater( return target_net_updater -def make_redq_loss( - model, cfg -) -> Tuple[REDQLoss_deprecated, Optional[TargetNetUpdater]]: - """Builds the REDQ loss module.""" - warnings.warn( - "This helper function will be deprecated in v0.4. 
Consider using the local helper in the REDQ example.", - category=DeprecationWarning, - ) - loss_kwargs = {} - if hasattr(cfg, "distributional") and cfg.distributional: - raise NotImplementedError - else: - loss_kwargs.update({"loss_function": cfg.loss_function}) - loss_kwargs.update({"delay_qvalue": cfg.loss == "double"}) - loss_class = REDQLoss_deprecated - if isinstance(model, ActorValueOperator): - actor_model = model.get_policy_operator() - qvalue_model = model.get_value_operator() - elif isinstance(model, ActorCriticOperator): - raise RuntimeError( - "Although REDQ Q-value depends upon selected actions, using the" - "ActorCriticOperator will lead to resampling of the actions when" - "computing the Q-value loss, which we don't want. Please use the" - "ActorValueOperator instead." - ) - else: - actor_model, qvalue_model = model - - loss_module = loss_class( - actor_network=actor_model, - qvalue_network=qvalue_model, - num_qvalue_nets=cfg.num_q_values, - gSDE=cfg.gSDE, - **loss_kwargs, - ) - loss_module.make_value_estimator(gamma=cfg.gamma) - target_net_updater = make_target_updater(cfg, loss_module) - return loss_module, target_net_updater - - def make_dqn_loss(model, cfg) -> Tuple[DQNLoss, Optional[TargetNetUpdater]]: """Builds the DQN loss module.""" loss_kwargs = {} diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 05f566674f2..0a3cea40b36 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -3,16 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import itertools -import warnings from dataclasses import dataclass -from typing import Optional, Sequence import torch - from tensordict import set_lazy_legacy from tensordict.nn import InteractionType -from torch import distributions as d, nn - +from torch import nn from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, @@ -25,7 +21,6 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( NoisyLinear, - NormalParamWrapper, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, @@ -37,8 +32,6 @@ TanhDelta, TanhNormal, ) -from torchrl.modules.distributions.continuous import SafeTanhTransform -from torchrl.modules.models.exploration import LazygSDEModule from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -47,19 +40,12 @@ RSSMPrior, RSSMRollout, ) -from torchrl.modules.models.models import ( - DdpgCnnActor, - DdpgCnnQNet, - DuelingCnnDQNet, - DuelingMlpDQNet, - MLP, -) +from torchrl.modules.models.models import DuelingCnnDQNet, DuelingMlpDQNet, MLP from torchrl.modules.tensordict_module import ( Actor, DistributionalQValueActor, QValueActor, ) -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator from torchrl.modules.tensordict_module.world_models import WorldModelWrapper from torchrl.trainers.helpers import transformed_env_constructor @@ -210,248 +196,6 @@ def make_dqn_actor( return model -def make_redq_model( - proof_environment: EnvBase, - cfg: "DictConfig", # noqa: F821 - device: DEVICE_TYPING = "cpu", - in_keys: Optional[Sequence[str]] = None, - actor_net_kwargs=None, - qvalue_net_kwargs=None, - observation_key=None, - **kwargs, -) -> nn.ModuleList: - """Actor and Q-value model constructor helper function for REDQ. - - Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd. - Other configurations can easily be implemented by modifying this function at will. 
- A single instance of the Q-value model is returned. It will be multiplicated by the loss function. - - Args: - proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec - cfg (DictConfig): contains arguments of the REDQ script - device (torch.device, optional): device on which the model must be cast. Default is "cpu". - in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of - `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen - based on the `cfg.from_pixels` argument. - actor_net_kwargs (dict, optional): kwargs of the actor MLP. - qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP. - - Returns: - A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. - - Examples: - >>> from torchrl.trainers.helpers.envs import parser_env_args - >>> from torchrl.trainers.helpers.models import make_redq_model, parser_model_args_continuous - >>> from torchrl.envs.libs.gym import GymEnv - >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose - >>> import hydra - >>> from hydra.core.config_store import ConfigStore - >>> import dataclasses - >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v4"), Compose(DoubleToFloat(["observation"]), - ... CatTensors(["observation"], "observation_vector"))) - >>> device = torch.device("cpu") - >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in - ... (RedqModelConfig, EnvConfig) - ... for config_field in dataclasses.fields(config_cls)] - >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields) - >>> cs = ConfigStore.instance() - >>> cs.store(name="config", node=Config) - >>> with initialize(config_path=None): - >>> cfg = compose(config_name="config") - >>> model = make_redq_model( - ... proof_environment, - ... device=device, - ... cfg=cfg, - ... 
) - >>> actor, qvalue = model - >>> td = proof_environment.reset() - >>> print(actor(td)) - TensorDict( - fields={ - done: Tensor(torch.Size([1]), dtype=torch.bool), - observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), - loc: Tensor(torch.Size([6]), dtype=torch.float32), - scale: Tensor(torch.Size([6]), dtype=torch.float32), - action: Tensor(torch.Size([6]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([1]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - >>> print(qvalue(td.clone())) - TensorDict( - fields={ - done: Tensor(torch.Size([1]), dtype=torch.bool), - observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), - loc: Tensor(torch.Size([6]), dtype=torch.float32), - scale: Tensor(torch.Size([6]), dtype=torch.float32), - action: Tensor(torch.Size([6]), dtype=torch.float32), - sample_log_prob: Tensor(torch.Size([1]), dtype=torch.float32), - state_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - warnings.warn( - "This helper function will be deprecated in v0.4. Consider using the local helper in the REDQ example.", - category=DeprecationWarning, - ) - tanh_loc = cfg.tanh_loc - default_policy_scale = cfg.default_policy_scale - gSDE = cfg.gSDE - - action_spec = proof_environment.action_spec - # obs_spec = proof_environment.observation_spec - # if observation_key is not None: - # obs_spec = obs_spec[observation_key] - # else: - # obs_spec_values = list(obs_spec.values()) - # if len(obs_spec_values) > 1: - # raise RuntimeError( - # "There is more than one observation in the spec, REDQ helper " - # "cannot infer automatically which to pick. " - # "Please indicate which key to read via the `observation_key` " - # "keyword in this helper." 
- # ) - # else: - # obs_spec = obs_spec_values[0] - - if actor_net_kwargs is None: - actor_net_kwargs = {} - if qvalue_net_kwargs is None: - qvalue_net_kwargs = {} - - linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear - - out_features_actor = (2 - gSDE) * action_spec.shape[-1] - if cfg.from_pixels: - if in_keys is None: - in_keys_actor = ["pixels"] - else: - in_keys_actor = in_keys - actor_net_kwargs_default = { - "mlp_net_kwargs": { - "layer_class": linear_layer_class, - "activation_class": ACTIVATIONS[cfg.activation], - }, - "conv_net_kwargs": {"activation_class": ACTIVATIONS[cfg.activation]}, - } - actor_net_kwargs_default.update(actor_net_kwargs) - actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default) - gSDE_state_key = "hidden" - out_keys_actor = ["param", "hidden"] - - value_net_default_kwargs = { - "mlp_net_kwargs": { - "layer_class": linear_layer_class, - "activation_class": ACTIVATIONS[cfg.activation], - }, - "conv_net_kwargs": {"activation_class": ACTIVATIONS[cfg.activation]}, - } - value_net_default_kwargs.update(qvalue_net_kwargs) - - in_keys_qvalue = ["pixels", "action"] - qvalue_net = DdpgCnnQNet(**value_net_default_kwargs) - else: - if in_keys is None: - in_keys_actor = ["observation_vector"] - else: - in_keys_actor = in_keys - - actor_net_kwargs_default = { - "num_cells": [cfg.actor_cells, cfg.actor_cells], - "out_features": out_features_actor, - "activation_class": ACTIVATIONS[cfg.activation], - } - actor_net_kwargs_default.update(actor_net_kwargs) - actor_net = MLP(**actor_net_kwargs_default) - out_keys_actor = ["param"] - gSDE_state_key = in_keys_actor[0] - - qvalue_net_kwargs_default = { - "num_cells": [cfg.qvalue_cells, cfg.qvalue_cells], - "out_features": 1, - "activation_class": ACTIVATIONS[cfg.activation], - } - qvalue_net_kwargs_default.update(qvalue_net_kwargs) - qvalue_net = MLP( - **qvalue_net_kwargs_default, - ) - in_keys_qvalue = in_keys_actor + ["action"] - - dist_class = TanhNormal - 
dist_kwargs = { - "min": action_spec.space.low, - "max": action_spec.space.high, - "tanh_loc": tanh_loc, - } - - if not gSDE: - actor_net = NormalParamWrapper( - actor_net, - scale_mapping=f"biased_softplus_{default_policy_scale}", - scale_lb=cfg.scale_lb, - ) - actor_module = SafeModule( - actor_net, - in_keys=in_keys_actor, - out_keys=["loc", "scale"] + out_keys_actor[1:], - ) - - else: - actor_module = SafeModule( - actor_net, - in_keys=in_keys_actor, - out_keys=["action"] + out_keys_actor[1:], # will be overwritten - ) - - if action_spec.domain == "continuous": - min = action_spec.space.low - max = action_spec.space.high - transform = SafeTanhTransform() - if (min != -1).any() or (max != 1).any(): - transform = d.ComposeTransform( - transform, - d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), - ) - else: - raise RuntimeError("cannot use gSDE with discrete actions") - - actor_module = SafeSequential( - actor_module, - SafeModule( - LazygSDEModule(transform=transform), - in_keys=["action", gSDE_state_key, "_eps_gSDE"], - out_keys=["loc", "scale", "action", "_eps_gSDE"], - ), - ) - - actor = ProbabilisticActor( - spec=action_spec, - in_keys=["loc", "scale"], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - default_interaction_type=InteractionType.RANDOM, - return_log_prob=True, - ) - qvalue = ValueOperator( - in_keys=in_keys_qvalue, - module=qvalue_net, - ) - model = nn.ModuleList([actor, qvalue]).to(device) - - # init nets - with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = proof_environment.fake_tensordict() - td = td.unsqueeze(-1) - td = td.to(device) - for net in model: - net(td) - del td - return model - - @set_lazy_legacy(False) def make_dreamer( cfg: "DictConfig", # noqa: F821 diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 526b3c967e8..ccd9bb23bb3 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -662,23 +662,13 
@@ def __init__( batch_size: Optional[int] = None, memmap: bool = False, device: DEVICE_TYPING = "cpu", - flatten_tensordicts: bool = None, + flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None, ) -> None: self.replay_buffer = replay_buffer self.batch_size = batch_size self.memmap = memmap self.device = device - if flatten_tensordicts is None: - warnings.warn( - "flatten_tensordicts default value has now changed " - "to False for a faster execution. Make sure your " - "code is robust to this change. To silence this warning, " - "pass flatten_tensordicts= in your code. " - "This warning will be removed in v0.4.", - category=DeprecationWarning, - ) - flatten_tensordicts = True self.flatten_tensordicts = flatten_tensordicts self.max_dims = max_dims