From 863121a275c74598cd228423cbea683e2b41806d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 19 Nov 2024 15:53:59 +0000 Subject: [PATCH] [BugFix] Fix failing tests ghstack-source-id: a43a2e3dbf76cd63c57ae00028df04b41a4e2f2b Pull Request resolved: https://github.com/pytorch/rl/pull/2582 --- .github/workflows/docs.yml | 2 + sota-implementations/ddpg/utils.py | 6 +- sota-implementations/dreamer/dreamer_utils.py | 1 + .../multiagent/maddpg_iddpg.py | 1 + sota-implementations/redq/config.yaml | 2 +- sota-implementations/redq/redq.py | 3 +- sota-implementations/redq/utils.py | 68 +++++---- sota-implementations/td3/utils.py | 3 +- sota-implementations/td3_bc/utils.py | 3 +- test/_utils_internal.py | 4 +- test/test_actors.py | 2 +- test/test_collector.py | 23 +++ test/test_exploration.py | 31 ++-- test/test_transforms.py | 3 + torchrl/collectors/collectors.py | 14 +- torchrl/envs/__init__.py | 2 +- torchrl/modules/distributions/continuous.py | 9 ++ .../modules/tensordict_module/exploration.py | 134 ++++++++++++------ torchrl/objectives/__init__.py | 1 + torchrl/objectives/crossq.py | 53 +++++-- 20 files changed, 248 insertions(+), 117 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index aea9d7e2c40..caa955c3bf0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -119,6 +119,8 @@ jobs: REF_TYPE=${{ github.ref_type }} REF_NAME=${{ github.ref_name }} + apt-get update + apt-get install rsync -y if [[ "${REF_TYPE}" == branch ]]; then if [[ "${REF_NAME}" == main ]]; then diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 338081a7e8d..9495fd038f2 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -234,7 +234,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): OrnsteinUhlenbeckProcessModule( spec=action_spec, annealing_num_steps=1_000_000, - ).to(device), + device=device, + ), ) elif cfg.network.noise_type == "gaussian": actor_model_explore = TensorDictSequential( @@ -245,7 +246,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): sigma_init=1.0, mean=0.0, std=0.1, - ).to(device), + device=device, + ), ) else: raise NotImplementedError diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 849d8c813b6..41ea170ac76 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -275,6 +275,7 @@ def make_dreamer( annealing_num_steps=1, mean=0.0, std=cfg.networks.exploration_noise, + device=device, ), ) diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index aad1df14fff..6199a888344 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -108,6 +108,7 @@ def train(cfg: "DictConfig"): # noqa: F821 spec=env.unbatched_action_spec, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, + device=cfg.train.device, ), ) diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml index 818f3386fda..e26b81aa459 100644 --- a/sota-implementations/redq/config.yaml +++ b/sota-implementations/redq/config.yaml @@ -30,7 +30,7 @@ collector: async_collection: 1 frames_per_batch: 1024 total_frames: 1_000_000 - device: cpu + device: env_per_collector: 1 init_random_frames: 50_000 multi_step: 1 diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index 865533aee2f..0732bf5f3b4 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -119,7 +119,8 @@ def main(cfg: "DictConfig"): # noqa: F821 annealing_num_steps=cfg.exploration.annealing_frames, sigma=cfg.exploration.ou_sigma, theta=cfg.exploration.ou_theta, - ).to(device), + device=device, + ), ) if device == torch.device("cpu"): # mostly for debugging diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 8312d359366..2823858af60 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -21,55 +21,59 @@ from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.collectors import DataCollectorBase -from torchrl.data import ReplayBuffer, TensorDictReplayBuffer -from torchrl.data.postprocs import MultiStep -from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data import ( + LazyMemmapStorage, + MultiStep, + PrioritizedSampler, + RandomSampler, + ReplayBuffer, + TensorDictReplayBuffer, +) from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import ParallelEnv -from torchrl.envs.common import EnvBase -from torchrl.envs.env_creator import env_creator, EnvCreator -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import ( +from torchrl.envs import ( CatFrames, CatTensors, CenterCrop, Compose, + DMControlEnv, DoubleToFloat, + env_creator, + EnvBase, + EnvCreator, + FlattenObservation, GrayScale, + gSDENoise, + GymEnv, + InitTracker, NoopResetEnv, ObservationNorm, + ParallelEnv, Resize, RewardScaling, + StepCounter, ToTensorImage, TransformedEnv, VecNorm, ) -from torchrl.envs.transforms.transforms import ( - FlattenObservation, - gSDENoise, - InitTracker, - StepCounter, -) from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( ActorCriticOperator, ActorValueOperator, + DdpgCnnActor, + DdpgCnnQNet, + MLP, NoisyLinear, NormalParamExtractor, + ProbabilisticActor, SafeModule, SafeSequential, + TanhNormal, + ValueOperator, ) -from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import SafeTanhTransform from torchrl.modules.models.exploration import LazygSDEModule -from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator -from torchrl.objectives import HardUpdate, SoftUpdate -from torchrl.objectives.common import LossModule +from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater from torchrl.objectives.deprecated import REDQLoss_deprecated -from torchrl.objectives.utils import TargetNetUpdater from torchrl.record.loggers import Logger from torchrl.record.recorder import VideoRecorder from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector @@ -518,7 +522,7 @@ def make_redq_model( actor_module = SafeSequential( actor_module, SafeModule( - LazygSDEModule(transform=transform), + LazygSDEModule(transform=transform, device=device), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], ), @@ -606,7 +610,9 @@ def make_transformed_env(**kwargs) -> TransformedEnv: categorical_action_encoding = cfg.env.categorical_action_encoding if custom_env is None and custom_env_maker is None: - if isinstance(cfg.collector.device, str): + if cfg.collector.device in ("", None): + device = "cpu" if not torch.cuda.is_available() else "cuda:0" + elif isinstance(cfg.collector.device, str): device = cfg.collector.device elif isinstance(cfg.collector.device, Sequence): device = cfg.collector.device[0] @@ -1000,11 +1006,14 @@ def make_collector_offpolicy( env_kwargs.update(make_env_kwargs) elif make_env_kwargs is not None: env_kwargs = make_env_kwargs - cfg.collector.device = ( - cfg.collector.device - if len(cfg.collector.device) > 1 - else cfg.collector.device[0] - ) + if cfg.collector.device in ("", None): + cfg.collector.device = "cpu" if not torch.cuda.is_available() else "cuda:0" + else: + cfg.collector.device = ( + cfg.collector.device + if len(cfg.collector.device) > 1 + else cfg.collector.device[0] + ) collector_helper_kwargs = { "env_fns": make_env, "env_kwargs": env_kwargs, @@ -1017,7 +1026,6 @@ def make_collector_offpolicy( # we already took care of building the make_parallel_env function "num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector, "device": cfg.collector.device, - "storing_device": cfg.collector.device, "init_random_frames": cfg.collector.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 60a4d046355..665c2e0c674 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -242,7 +242,8 @@ def make_td3_agent(cfg, train_env, eval_env, device): mean=0, std=0.1, spec=action_spec, - ).to(device), + device=device, + ), ) return model, actor_model_explore diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index 3dcbd45d30c..582afaaac04 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -183,7 +183,8 @@ def make_td3_agent(cfg, train_env, device): mean=0, std=0.1, spec=action_spec, - ).to(device), + device=device, + ), ) return model, actor_model_explore diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 683a0d60182..dea0d136844 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -167,11 +167,11 @@ def get_available_devices(): def get_default_devices(): num_cuda = torch.cuda.device_count() if num_cuda == 0: + if torch.mps.is_available(): + return [torch.device("mps:0")] return [torch.device("cpu")] elif num_cuda == 1: return [torch.device("cuda:0")] - elif torch.mps.is_available(): - return [torch.device("mps:0")] else: # then run on all devices return get_available_devices() diff --git a/test/test_actors.py b/test/test_actors.py index c50bf7b9e62..ae1436b5afc 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -47,7 +47,7 @@ ("data", "sample_log_prob"), ], ) -def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3): +def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1): env = NestedCountingEnv(nested_dim=nested_dim) action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) policy_module = TensorDictModule( diff --git a/test/test_collector.py b/test/test_collector.py index 1309254ce2d..38191a46eaa 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3172,6 +3172,29 @@ def make_and_test_policy( ) +@pytest.mark.parametrize( + "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] +) +def test_no_stopiteration(ctype): + # Tests that there is no StopIteration raised and that the length of the collector is properly set + if ctype is SyncDataCollector: + envs = SerialEnv(16, CountingEnv) + else: + envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] + + collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) + try: + c_iter = iter(collector) + assert len(collector) == 2 + for i in range(len(collector)): # noqa: B007 + c = next(c_iter) + assert c is not None + assert i == 1 + finally: + collector.shutdown() + del collector + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py index b69b743e48c..0020c9dfc6b 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -241,8 +241,8 @@ def test_ou( self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0 ): torch.manual_seed(seed) - net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( - device + net = nn.Sequential( + nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor() ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,)) @@ -252,13 +252,13 @@ def test_ou( in_keys=["loc", "scale"], distribution_class=TanhNormal, default_interaction_type=InteractionType.RANDOM, - ).to(device) + ) if interface == "module": - ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device) exploratory_policy = TensorDictSequential(policy, ou) else: - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device) ou = exploratory_policy tensordict = TensorDict( @@ -338,10 +338,10 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0 if interface == "module": exploratory_policy = TensorDictSequential( - policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + policy, OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device) ) else: - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device) exploratory_policy(env.reset()) collector = SyncDataCollector( create_env_fn=env, @@ -456,10 +456,10 @@ def test_additivegaussian_sd( device=device, ) if interface == "module": - exploratory_policy = AdditiveGaussianModule(action_spec).to(device) + exploratory_policy = AdditiveGaussianModule(action_spec, device=device) else: - net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( - device + net = nn.Sequential( + nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor() ) module = SafeModule( net, @@ -473,10 +473,10 @@ def test_additivegaussian_sd( in_keys=["loc", "scale"], distribution_class=TanhNormal, default_interaction_type=InteractionType.RANDOM, - ).to(device) + ) given_spec = action_spec if spec_origin == "spec" else None - exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to( - device + exploratory_policy = AdditiveGaussianWrapper( + policy, spec=given_spec, device=device ) if spec_origin is not None: sigma_init = ( @@ -727,10 +727,7 @@ def test_gsde( @pytest.mark.parametrize("std", [1, 2]) @pytest.mark.parametrize("sigma_init", [None, 1.5, 3]) @pytest.mark.parametrize("learn_sigma", [False, True]) -@pytest.mark.parametrize( - "device", - [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")], -) +@pytest.mark.parametrize("device", get_default_devices()) def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_sigma): torch.manual_seed(0) state = torch.randn(10000, *state_dim, device=device) * std + mean diff --git a/test/test_transforms.py b/test/test_transforms.py index d4b9918f062..56a39218f5f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2076,7 +2076,10 @@ def test_transform_rb(self, rbclass): ): td = rb.sample(10) + @retry(AssertionError, tries=10, delay=0) def test_collector_match(self): + torch.manual_seed(0) + # The counter in the collector should match the one from the transform t = TrajCounter() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 319722a552e..16eb5904b84 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -146,6 +146,7 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None total_frames: int + requested_frames_per_batch: int frames_per_batch: int trust_policy: bool compiled_policy: bool @@ -161,6 +162,8 @@ def _get_policy_and_device( ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. + We want to copy the policy and then move the data there, not call policy.to(device). + Args: policy (TensorDictModule, optional): a policy to be used observation_spec (TensorSpec, optional): spec of the observations @@ -218,7 +221,7 @@ def map_weight( weight = weight.data if weight.device != policy_device: weight = weight.to(policy_device) - elif weight.device.type in ("cpu", "mps"): + elif weight.device.type in ("cpu",): weight = weight.share_memory_() if is_param: weight = Parameter(weight, requires_grad=False) @@ -232,7 +235,7 @@ def map_weight( policy = deepcopy(policy) param_and_buf.apply( - functools.partial(map_weight), + map_weight, filter_empty=False, ).to_module(policy) return policy, get_original_weights @@ -305,7 +308,7 @@ def __class_getitem__(self, index): def __len__(self) -> int: if self.total_frames > 0: - return -(self.total_frames // -self.frames_per_batch) + return -(self.total_frames // -self.requested_frames_per_batch) raise RuntimeError("Non-terminating collectors do not have a length") @@ -700,7 +703,7 @@ def __init__( remainder = total_frames % frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})." + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " f"This means {frames_per_batch - remainder} additional frames will be collected." "To silence this message, set the environment variable RL_WARNINGS to False." ) @@ -737,8 +740,8 @@ def __init__( f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " "To silence this message, set the environment variable RL_WARNINGS to False." ) - self.requested_frames_per_batch = int(frames_per_batch) self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env self.exploration_type = ( exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE ) @@ -1653,6 +1656,7 @@ def __init__( self._get_weights_fn_dict[policy_device] = get_weights_fn self.policy = policy + remainder = 0 if total_frames is None or total_frames < 0: total_frames = float("inf") else: diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 56f7a5a3332..4cfb00cc307 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -6,7 +6,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict from .custom import PendulumEnv, TicTacToeEnv -from .env_creator import EnvCreator, get_env_metadata +from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( BraxEnv, diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 6c200c15ee4..f32a3b0c6fa 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -611,6 +611,15 @@ def __init__( event_shape = param.shape[-1:] super().__init__(batch_shape=batch_shape, event_shape=event_shape) + def expand(self, batch_shape: torch.Size, _instance=None): + if self.batch_shape != tuple(batch_shape): + return type(self)( + self.param.expand((*batch_shape, *self.event_shape)), + atol=self.atol, + rtol=self.rtol, + ) + return self + def update(self, param): self.param = param diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index df947236970..9acfae1aa21 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings from typing import Optional, Union @@ -15,6 +17,7 @@ TensorDictModuleWrapper, ) from tensordict.utils import expand_as_right, expand_right, NestedKey +from torch import nn from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.envs.utils import exploration_type, ExplorationType @@ -232,6 +235,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): is set to False but the spec is passed, the projection will still happen. Default is True. + device (torch.device, optional): the device where the buffers have to be stored. .. note:: Once an environment has been wrapped in :class:`AdditiveGaussianWrapper`, it is @@ -255,6 +259,7 @@ def __init__( action_key: Optional[NestedKey] = "action", spec: Optional[TensorSpec] = None, safe: Optional[bool] = True, + device: torch.device | None = None, ): warnings.warn( "AdditiveGaussianWrapper is deprecated and will be removed " @@ -262,15 +267,22 @@ def __init__( "instead.", category=DeprecationWarning, ) + if device is None and hasattr(policy, "parameters"): + for p in policy.parameters(): + device = p.device + break + super().__init__(policy) if sigma_end > sigma_init: raise RuntimeError("sigma should decrease over time or be constant") - self.register_buffer("sigma_init", torch.tensor([sigma_init])) - self.register_buffer("sigma_end", torch.tensor([sigma_end])) + self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) + self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean])) - self.register_buffer("std", torch.tensor([std])) - self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) + self.register_buffer("mean", torch.tensor([mean], device=device)) + self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer( + "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + ) self.action_key = action_key self.out_keys = list(self.td_module.out_keys) if action_key not in self.out_keys: @@ -312,19 +324,23 @@ def step(self, frames: int = 1) -> None: for _ in range(frames): self.sigma.data.copy_( torch.maximum( - self.sigma_end( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ), - ) + self.sigma_end, + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps, + ), ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: sigma = self.sigma - noise = torch.normal( - mean=torch.ones(action.shape) * self.mean, - std=torch.ones(action.shape) * self.std, - ).to(action.device) + mean = self.mean.expand(action.shape) + std = self.std.expand(action.shape) + if not mean.dtype.is_floating_point: + mean = mean.to(torch.get_default_dtype()) + if not std.dtype.is_floating_point: + std = std.to(torch.get_default_dtype()) + noise = torch.normal(mean=mean, std=std) + if noise.device != action.device: + noise = noise.to(action.device) action = action + noise * sigma spec = self.spec spec = spec[self.action_key] @@ -372,6 +388,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space given the :obj:`TensorSpec.project` heuristic. default: True + device (torch.device, optional): the device where the buffers have to be stored. .. note:: It is @@ -394,6 +411,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", safe: bool = True, + device: torch.device | None = None, ): if not isinstance(sigma_init, float): warnings.warn("eps_init should be a float.") @@ -405,12 +423,14 @@ def __init__( super().__init__() - self.register_buffer("sigma_init", torch.tensor([sigma_init])) - self.register_buffer("sigma_end", torch.tensor([sigma_end])) + self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) + self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean])) - self.register_buffer("std", torch.tensor([std])) - self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) + self.register_buffer("mean", torch.tensor([mean], device=device)) + self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer( + "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + ) if spec is not None: if not isinstance(spec, Composite) and len(self.out_keys) >= 1: @@ -448,10 +468,15 @@ def step(self, frames: int = 1) -> None: def _add_noise(self, action: torch.Tensor) -> torch.Tensor: sigma = self.sigma - noise = torch.normal( - mean=torch.ones(action.shape) * self.mean, - std=torch.ones(action.shape) * self.std, - ).to(action.device) + mean = self.mean.expand(action.shape) + std = self.std.expand(action.shape) + if not mean.dtype.is_floating_point: + mean = mean.to(torch.get_default_dtype()) + if not std.dtype.is_floating_point: + std = std.to(torch.get_default_dtype()) + noise = torch.normal(mean=mean, std=std) + if noise.device != action.device: + noise = noise.to(action.device) action = action + noise * sigma spec = self.spec[self.action_key] action = spec.project(action) @@ -530,6 +555,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space given the :obj:`TensorSpec.project` heuristic. default: True + device (torch.device, optional): the device where the buffers have to be stored. Examples: >>> import torch @@ -573,6 +599,7 @@ def __init__( spec: TensorSpec = None, safe: bool = True, key: Optional[NestedKey] = None, + device: torch.device | None = None, ): warnings.warn( "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed " @@ -580,6 +607,10 @@ def __init__( "instead.", category=DeprecationWarning, ) + if device is None and hasattr(policy, "parameters"): + for p in policy.parameters(): + device = p.device + break if key is not None: action_key = key warnings.warn( @@ -595,16 +626,19 @@ def __init__( sigma_min=sigma_min, n_steps_annealing=n_steps_annealing, key=action_key, + device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) + self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " f"got eps_init={eps_init} and eps_end={eps_end}" ) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer( + "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + ) self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys self.is_init_key = is_init_key noise_key = self.ou.noise_key @@ -746,6 +780,7 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): is set to False but the spec is passed, the projection will still happen. Default is True. + device (torch.device, optional): the device where the buffers have to be stored. Examples: >>> import torch @@ -782,13 +817,14 @@ def __init__( mu: float = 0.0, sigma: float = 0.2, dt: float = 1e-2, - x0: Optional[Union[torch.Tensor, np.ndarray]] = None, - sigma_min: Optional[float] = None, + x0: torch.Tensor | np.ndarray | None = None, + sigma_min: float | None = None, n_steps_annealing: int = 1000, *, - action_key: Optional[NestedKey] = "action", - is_init_key: Optional[NestedKey] = "is_init", + action_key: NestedKey = "action", + is_init_key: NestedKey = "is_init", safe: bool = True, + device: torch.device | None = None, ): super().__init__() @@ -801,17 +837,20 @@ def __init__( sigma_min=sigma_min, n_steps_annealing=n_steps_annealing, key=action_key, + device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) + self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) + self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " f"got eps_init={eps_init} and eps_end={eps_end}" ) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + self.register_buffer( + "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + ) self.in_keys = [self.ou.key] self.out_keys = [self.ou.key] + self.ou.out_keys @@ -883,7 +922,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab -class _OrnsteinUhlenbeckProcess: +class _OrnsteinUhlenbeckProcess(nn.Module): def __init__( self, theta: float, @@ -895,7 +934,11 @@ def __init__( n_steps_annealing: int = 1000, key: Optional[NestedKey] = "action", is_init_key: Optional[NestedKey] = "is_init", + device: torch.device | None = None, ): + super().__init__() + self.register_buffer("_empty_tensor_device", torch.zeros(0, device=device)) + self.mu = mu self.sigma = sigma @@ -917,6 +960,13 @@ def __init__( self._noise_key = "_ou_prev_noise" self._steps_key = "_ou_steps" self.out_keys = [self.noise_key, self.steps_key] + self._auto_buffer() + + def _auto_buffer(self): + for key, item in list(self.__dict__.items()): + if isinstance(item, torch.Tensor): + delattr(self, key) + self.register_buffer(key, item) @property def noise_key(self): @@ -932,12 +982,14 @@ def _make_noise_pair( tensordict: TensorDictBase, is_init: torch.Tensor, ): + device = tensordict.device + if device is None: + device = self._empty_tensor_device.device + if self.steps_key not in tensordict.keys(): - noise = torch.zeros( - tensordict.get(self.key).shape, device=tensordict.device - ) + noise = torch.zeros(tensordict.get(self.key).shape, device=device) steps = torch.zeros( - action_tensordict.batch_size, dtype=torch.long, device=tensordict.device + action_tensordict.batch_size, dtype=torch.long, device=device ) tensordict.set(self.noise_key, noise) tensordict.set(self.steps_key, steps) @@ -946,8 +998,8 @@ def _make_noise_pair( noise = tensordict.get(self.noise_key).clone() steps = tensordict.get(self.steps_key).clone() if is_init is not None: - noise = torch.masked_fill(noise, is_init, 0) - steps = torch.masked_fill(steps, is_init, 0) + noise = torch.masked_fill(noise, expand_right(is_init, noise.shape), 0) + steps = torch.masked_fill(steps, expand_right(is_init, steps.shape), 0) return noise, steps def add_sample( @@ -972,7 +1024,7 @@ def add_sample( is_init = is_init.squeeze(-1) # Squeeze dangling dim if ( action_tensordict.ndim >= is_init.ndim - ): # if is_init has less dimensions than action_tensordict we expand it + ): # if is_init has fewer dimensions than action_tensordict we expand it is_init = expand_right(is_init, action_tensordict.shape) else: is_init = is_init.sum( diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 01f993e629a..f8f5636db95 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -29,5 +29,6 @@ hold_out_params, next_state_value, SoftUpdate, + TargetNetUpdater, ValueEstimators, ) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index ca6559ac5b8..801180901a7 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -341,7 +341,7 @@ def __init__( self._make_vmap() self.reduction = reduction # init target entropy - _ = self.target_entropy + self.maybe_init_target_entropy() def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -356,22 +356,23 @@ def target_entropy_buffer(self): """ return self.target_entropy - @property - def target_entropy(self): - target_entropy = self._buffers.get("_target_entropy", None) - if target_entropy is not None: - return target_entropy + def maybe_init_target_entropy(self, fault_tolerant=True): + """Initialize the target entropy. + + Args: + fault_tolerant (bool, optional): if ``True``, returns None if the target entropy + cannot be determined. Raises an exception otherwise. Defaults to ``True``. + + """ + if "_target_entropy" in self._buffers: + return target_entropy = self._target_entropy - action_spec = self._action_spec - actor_network = self.actor_network - device = next(self.parameters()).device if target_entropy == "auto": - action_spec = ( - action_spec - if action_spec is not None - else getattr(actor_network, "spec", None) - ) + device = next(self.parameters()).device + action_spec = self.get_action_spec() if action_spec is None: + if fault_tolerant: + return raise RuntimeError( "Cannot infer the dimensionality of the action. Consider providing " "the target entropy explicitely or provide the spec of the " @@ -379,6 +380,8 @@ def target_entropy(self): ) if not isinstance(action_spec, Composite): action_spec = Composite({self.tensor_keys.action: action_spec}) + elif fault_tolerant and self.tensor_keys.action not in action_spec: + return if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -397,6 +400,28 @@ def target_entropy(self): ) return self._target_entropy + def get_action_spec(self): + action_spec = self._action_spec + actor_network = self.actor_network + action_spec = ( + action_spec + if action_spec is not None + else getattr(actor_network, "spec", None) + ) + return action_spec + + @property + def target_entropy(self): + target_entropy = self._buffers.get("_target_entropy") + if target_entropy is not None: + return target_entropy + return self.maybe_init_target_entropy(fault_tolerant=False) + + def set_keys(self, **kwargs) -> None: + out = super().set_keys(**kwargs) + self.maybe_init_target_entropy() + return out + state_dict = _delezify(LossModule.state_dict) load_state_dict = _delezify(LossModule.load_state_dict)