From 87f66e86fc51ae43dca40ac743b3e8e0987a78f1 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Mon, 22 Jul 2024 05:27:56 -0700 Subject: [PATCH] [Refactor] Update all instances of exploration `*Wrapper` to `*Module` (#2298) Co-authored-by: Vincent Moens --- sota-implementations/ddpg/ddpg.py | 2 +- sota-implementations/ddpg/utils.py | 30 ++++++++++++------- sota-implementations/dreamer/dreamer.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 17 ++++++----- .../multiagent/maddpg_iddpg.py | 15 ++++++---- sota-implementations/redq/redq.py | 16 ++++++---- sota-implementations/td3/td3.py | 2 +- sota-implementations/td3/utils.py | 19 +++++++----- sota-implementations/td3_bc/utils.py | 19 +++++++----- test/test_collector.py | 15 +++++++--- test/test_tensordictmodules.py | 10 ++++--- tutorials/sphinx-tutorials/coding_ddpg.py | 17 ++++++----- .../sphinx-tutorials/getting-started-1.py | 4 +-- .../multiagent_competitive_ddpg.py | 19 +++++++----- 14 files changed, 113 insertions(+), 74 deletions(-) diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index a92ee6185c3..1b038d69d15 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for _, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start # Update exploration policy - exploration_policy.step(tensordict.numel()) + exploration_policy[1].step(tensordict.numel()) # Update weights of the inference policy collector.update_policy_weights_() diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 45c6da7a342..338081a7e8d 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -6,6 +6,8 @@ import torch +from tensordict.nn import TensorDictSequential + from torch import nn, optim from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer @@ -25,9 +27,9 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MLP, - OrnsteinUhlenbeckProcessWrapper, + OrnsteinUhlenbeckProcessModule, SafeModule, SafeSequential, TanhModule, @@ -227,18 +229,24 @@ def make_ddpg_agent(cfg, train_env, eval_env, device): # Exploration wrappers: if cfg.network.noise_type == "ou": - actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore = TensorDictSequential( model[0], - annealing_num_steps=1_000_000, - ).to(device) + OrnsteinUhlenbeckProcessModule( + spec=action_spec, + annealing_num_steps=1_000_000, + ).to(device), + ) elif cfg.network.noise_type == "gaussian": - actor_model_explore = AdditiveGaussianWrapper( + actor_model_explore = TensorDictSequential( model[0], - sigma_end=1.0, - sigma_init=1.0, - mean=0.0, - std=0.1, - ).to(device) + AdditiveGaussianModule( + spec=action_spec, + sigma_end=1.0, + sigma_init=1.0, + mean=0.0, + std=0.1, + ).to(device), + ) else: raise NotImplementedError diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e521b9df386..f28fac8e675 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -279,7 +279,7 @@ def compile_rssms(module): if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - policy.step(current_frames) + policy[1].step(current_frames) 
collector.update_policy_weights_() # Evaluation if (i % eval_iter) == 0: diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 73baa310821..6745b1a079a 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -52,7 +52,7 @@ ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, DreamerActor, IndependentNormal, MLP, @@ -266,13 +266,16 @@ def make_dreamer( test_env=test_env, ) # Exploration noise to be added to the actor_realworld - actor_realworld = AdditiveGaussianWrapper( + actor_realworld = TensorDictSequential( actor_realworld, - sigma_init=1.0, - sigma_end=1.0, - annealing_num_steps=1, - mean=0.0, - std=cfg.networks.exploration_noise, + AdditiveGaussianModule( + spec=test_env.action_spec, + sigma_init=1.0, + sigma_end=1.0, + annealing_num_steps=1, + mean=0.0, + std=cfg.networks.exploration_noise, + ), ) # Make Critic diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index bd44bb0a043..e9de2ac4e14 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -7,7 +7,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector @@ -18,7 +18,7 @@ from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, ProbabilisticActor, TanhDelta, ValueOperator, @@ -102,10 +102,13 @@ def train(cfg: "DictConfig"): # noqa: F821 return_log_prob=False, ) - policy_explore = AdditiveGaussianWrapper( + policy_explore = TensorDictSequential( policy, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, + AdditiveGaussianModule( + spec=env.unbatched_action_spec, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + ), ) # Critic @@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - policy_explore.step(frames=current_frames) # Update exploration annealing + policy_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index c6b96db9292..eb802f6773d 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -8,10 +8,11 @@ import hydra import torch.cuda +from tensordict.nn import TensorDictSequential from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import OrnsteinUhlenbeckProcessWrapper +from torchrl.modules import OrnsteinUhlenbeckProcessModule from torchrl.record import VideoRecorder from torchrl.record.loggers import get_logger from utils import ( @@ -111,12 +112,15 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.exploration.ou_exploration: if cfg.exploration.gSDE: raise RuntimeError("gSDE and ou_exploration are incompatible") - 
actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore = TensorDictSequential( actor_model_explore, - annealing_num_steps=cfg.exploration.annealing_frames, - sigma=cfg.exploration.ou_sigma, - theta=cfg.exploration.ou_theta, - ).to(device) + OrnsteinUhlenbeckProcessModule( + spec=actor_model_explore.spec, + annealing_num_steps=cfg.exploration.annealing_frames, + sigma=cfg.exploration.ou_sigma, + theta=cfg.exploration.ou_theta, + ).to(device), + ) if device == torch.device("cpu"): # mostly for debugging actor_model_explore.share_memory() diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 5fbc9b032d7..632ee58503d 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() for tensordict in collector: sampling_time = time.time() - sampling_start - exploration_policy.step(tensordict.numel()) + exploration_policy[1].step(tensordict.numel()) # Update weights of the inference policy collector.update_policy_weights_() diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index c597ae205a2..60a4d046355 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -7,6 +7,7 @@ from contextlib import nullcontext import torch +from tensordict.nn import TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -27,7 +28,7 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MLP, SafeModule, SafeSequential, @@ -233,14 +234,16 @@ def make_td3_agent(cfg, train_env, eval_env, device): eval_env.close() # Exploration wrappers: - actor_model_explore = AdditiveGaussianWrapper( + actor_model_explore = TensorDictSequential( model[0], - sigma_init=1, - sigma_end=1, - mean=0, - std=0.1, - spec=action_spec, - ).to(device) + AdditiveGaussianModule( + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device), + ) return model, actor_model_explore diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index 3772eefccde..3dcbd45d30c 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -5,6 +5,7 @@ import functools import torch +from tensordict.nn import TensorDictSequential from torch import nn, optim from torchrl.data.datasets.d4rl import D4RLExperienceReplay @@ -24,7 +25,7 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MLP, SafeModule, SafeSequential, @@ -174,14 +175,16 @@ def make_td3_agent(cfg, train_env, device): del td # Exploration wrappers: - actor_model_explore = AdditiveGaussianWrapper( + actor_model_explore = TensorDictSequential( model[0], - sigma_init=1, - sigma_end=1, - mean=0, - std=0.1, - spec=action_spec, - ).to(device) + AdditiveGaussianModule( + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device), + ) return model, actor_model_explore diff --git a/test/test_collector.py b/test/test_collector.py index 12ec490e7e2..7d7208aead0 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -92,7 +92,7 @@ PARTIAL_MISSING_ERR, RandomPolicy, ) -from torchrl.modules import Actor, 
OrnsteinUhlenbeckProcessWrapper, SafeModule +from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" @@ -1291,8 +1291,13 @@ def make_env(): policy_module, in_keys=["observation"], out_keys=["action"] ) copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key]) - policy = TensorDictSequential(policy, copier) - policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) + policy_explore = TensorDictSequential( + policy, + copier, + OrnsteinUhlenbeckProcessModule( + spec=CompositeSpec({key: None for key in policy.out_keys}) + ), + ) collector_kwargs = { "create_env_fn": make_env, @@ -2472,7 +2477,9 @@ def make_env(): obs_spec = dummy_env.observation_spec["observation"] policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) policy = Actor(policy_module, spec=dummy_env.action_spec) - policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) + policy_explore = TensorDictSequential( + policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec) + ) collector_kwargs = { "create_env_fn": make_env, diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 6f81a9748bc..a6f66291719 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -27,7 +27,7 @@ ) from torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, DecisionTransformerInferenceWrapper, DTActor, GRUModule, @@ -1363,17 +1363,19 @@ def test_actor_critic_specs(): out_keys=[action_key], ) original_spec = spec.clone() - module = AdditiveGaussianWrapper(policy_module, spec=spec, action_key=action_key) + module = TensorDictSequential( + policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key) + ) value_module = ValueOperator( module=module, in_keys=[("agents", "observation"), action_key], out_keys=[("agents", "state_action_value")], ) assert original_spec == spec - assert module.spec == spec + assert module[1].spec == spec DDPGLoss(actor_network=module, value_network=value_module) assert original_spec == spec - assert module.spec == spec + assert module[1].spec == spec def test_vmapmodule(): diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index c52cb3bd5b2..777a1dbd578 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -188,7 +188,7 @@ # Later, we will see how the target parameters should be updated in TorchRL. # -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential def _init( @@ -722,7 +722,7 @@ def get_env_stats(): ActorCriticWrapper, DdpgMlpActor, DdpgMlpQNet, - OrnsteinUhlenbeckProcessWrapper, + OrnsteinUhlenbeckProcessModule, ProbabilisticActor, TanhDelta, ValueOperator, @@ -781,15 +781,18 @@ def make_ddpg_actor( # Exploration # ~~~~~~~~~~~ # -# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` +# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule` # exploration module, as suggested in the original paper. 
# Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 -actor_model_explore = OrnsteinUhlenbeckProcessWrapper( +actor_model_explore = TensorDictSequential( actor, - annealing_num_steps=annealing_frames, -).to(device) + OrnsteinUhlenbeckProcessModule( + spec=actor.spec.clone(), + annealing_num_steps=annealing_frames, + ).to(device), +) if device == torch.device("cpu"): actor_model_explore.share_memory() @@ -1173,7 +1176,7 @@ def ceil_div(x, y): ) # update the exploration strategy - actor_model_explore.step(current_frames) + actor_model_explore[1].step(current_frames) collector.shutdown() del collector diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index fb33d520860..4cd35c9bbe7 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -191,8 +191,8 @@ # also palliate to this with its exploration modules. # We will take the example of the :class:`~torchrl.modules.EGreedyModule` # exploration module (check also -# :class:`~torchrl.modules.AdditiveGaussianWrapper` and -# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`). +# :class:`~torchrl.modules.AdditiveGaussianModule` and +# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`). # To see this module in action, let's revert to a deterministic policy: from tensordict.nn import TensorDictSequential diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index b4bc38eb7bf..2600df2f752 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -125,7 +125,7 @@ ) from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MultiAgentMLP, ProbabilisticActor, TanhDelta, @@ -499,7 +499,7 @@ # Since the DDPG policy is deterministic, we need a way to perform exploration during collection. # # For this purpose, we need to append an exploration layer to our policies before passing them to the collector. -# In this case we use a :class:`~torchrl.modules.AdditiveGaussianWrapper`, which adds gaussian noise to our action +# In this case we use a :class:`~torchrl.modules.AdditiveGaussianModule`, which adds gaussian noise to our action # (and clamps it if the noise makes the action out of bounds). # # This exploration wrapper uses a ``sigma`` parameter which is multiplied by the noise to determine its magnitude. @@ -510,13 +510,16 @@ exploration_policies = {} for group, _agents in env.group_map.items(): - exploration_policy = AdditiveGaussianWrapper( + exploration_policy = TensorDictSequential( policies[group], - annealing_num_steps=total_frames - // 2, # Number of frames after which sigma is sigma_end - action_key=(group, "action"), - sigma_init=0.9, # Initial value of the sigma - sigma_end=0.1, # Final value of the sigma + AdditiveGaussianModule( + spec=policies[group].spec, + annealing_num_steps=total_frames + // 2, # Number of frames after which sigma is sigma_end + action_key=(group, "action"), + sigma_init=0.9, # Initial value of the sigma + sigma_end=0.1, # Final value of the sigma + ), ) exploration_policies[group] = exploration_policy
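
The same migration pattern repeats across every file in this patch: the exploration `*Wrapper` that used to wrap the policy is replaced by an exploration `*Module` appended as the second node of a `TensorDictSequential`, and noise annealing is now driven by indexing that node (`policy[1].step(...)`). Below is a minimal, self-contained sketch of the new pattern; it is not part of the patch, and the observation/action sizes, spec, and hyperparameters are illustrative placeholders.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import Actor, AdditiveGaussianModule

# Illustrative 2-dim continuous action space and a toy linear actor.
action_spec = BoundedTensorSpec(-1.0, 1.0, (2,))
actor = Actor(nn.Linear(4, 2), spec=action_spec)  # reads "observation", writes "action"

# Old: exploration_policy = AdditiveGaussianWrapper(actor, spec=action_spec, ...)
# New: the exploration module is a separate node appended after the policy.
exploration_policy = TensorDictSequential(
    actor,
    AdditiveGaussianModule(
        spec=action_spec,
        sigma_init=1.0,
        sigma_end=0.1,
        annealing_num_steps=1_000,
    ),
)

td = TensorDict({"observation": torch.randn(8, 4)}, [8])
with set_exploration_type(ExplorationType.RANDOM):
    exploration_policy(td)  # adds annealed Gaussian noise, clamped to the action spec

# Annealing is stepped on the exploration module itself, which is why the
# training loops above now call policy[1].step(...) instead of policy.step(...).
exploration_policy[1].step(td.numel())

`OrnsteinUhlenbeckProcessModule` follows the same construction, as the `ddpg`, `redq` and `coding_ddpg` hunks above show; only the noise process and its keyword arguments differ.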