From 6c68f7e0c07526efd8ff2b6e9626a0234326f600 Mon Sep 17 00:00:00 2001 From: Sebastian Dittert Date: Tue, 9 Jan 2024 15:09:31 +0100 Subject: [PATCH] [Algorithm] Update discrete SAC example (#1745) Co-authored-by: vmoens --- .../linux_examples/scripts/run_test.sh | 28 +- examples/discrete_sac/config.yaml | 79 ++-- examples/discrete_sac/discrete_sac.py | 431 ++++++------------ examples/discrete_sac/utils.py | 288 ++++++++++++ torchrl/objectives/sac.py | 2 +- 5 files changed, 493 insertions(+), 335 deletions(-) create mode 100644 examples/discrete_sac/utils.py diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 163fec36721..528acd6d331 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -148,6 +148,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ env.name=Pendulum-v1 \ network.device=cuda:0 \ logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/discrete_sac.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=1 \ + collector.device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + network.device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=CartPole-v1 \ + logger.backend= # logger.record_video=True \ # logger.record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ @@ -246,20 +260,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ logger.record_frames=4 \ buffer.size=120 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ - collector.total_frames=48 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=16 \ - collector.env_per_collector=1 \ - collector.device=cuda:0 \ - optim.batch_size=10 \ - optim.utd_ratio=1 \ - network.device=cuda:0 \ - optim.batch_size=10 \ - optim.utd_ratio=1 \ - replay_buffer.size=120 \ - env.name=Pendulum-v1 \ - logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ collector.total_frames=48 \ optim.batch_size=10 \ diff --git a/examples/discrete_sac/config.yaml b/examples/discrete_sac/config.yaml index df51781bf43..98f908e84d8 100644 --- a/examples/discrete_sac/config.yaml +++ b/examples/discrete_sac/config.yaml @@ -1,37 +1,52 @@ -# Logger -logger: wandb -exp_name: discrete_sac -record_interval: 1 -mode: online -# Environment -env_name: CartPole-v1 -frame_skip: 1 -from_pixels: false -reward_scaling: 1.0 -init_env_steps: 1000 -seed: 42 +# task and env +env: + name: CartPole-v1 + task: "" + exp_name: ${env.name}_DiscreteSAC + library: gym + seed: 42 + max_episode_steps: 500 -# Collector -env_per_collector: 1 -max_frames_per_traj: 500 -total_frames: 1000000 -init_random_frames: 5000 -frames_per_batch: 500 # 500 * env_per_collector +# collector +collector: + total_frames: 25000 + init_random_frames: 1000 + init_env_steps: 1000 + frames_per_batch: 500 + reset_at_each_iter: False + device: cuda:0 + env_per_collector: 1 + num_workers: 1 -# Replay Buffer -prb: 0 -buffer_size: 1000000 +# replay buffer +replay_buffer: + prb: 0 # use prioritized experience replay + size: 1000000 + scratch_dir: ${env.exp_name}_${env.seed} -# Optimization -utd_ratio: 1.0 -gamma: 0.99 -batch_size: 256 -lr: 3.0e-4 -weight_decay: 0.0 -target_update_polyak: 0.995 -target_entropy_weight: 0.2 -# default is 0.98 but needs to be decreased for env -# with small action space +# optim +optim: + utd_ratio: 1.0 + gamma: 0.99 + batch_size: 256 + lr: 3.0e-4 + weight_decay: 0.0 + target_update_polyak: 0.995 + target_entropy_weight: 0.2 + target_entropy: "auto" + loss_function: l2 + # default is 0.98 but needs to be decreased for env + # with small action space -device: cpu +# network +network: + hidden_sizes: [256, 256] + activation: relu + device: "cuda:0" + +# logging +logger: + backend: wandb + mode: online + eval_iter: 5000 diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 29ccd1eca6d..1ff922f41fb 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -2,276 +2,123 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""Discrete SAC Example. +This is a simple self-contained example of a discrete SAC training script. + +It supports gym state environments like CartPole. + +The helper functions are coded in the utils.py associated with this script. +""" + +import time import hydra import numpy as np import torch import torch.cuda import tqdm -from tensordict.nn import InteractionType, TensorDictModule - -from torch import nn, optim -from torchrl.collectors import SyncDataCollector -from torchrl.data import ( - CompositeSpec, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) - -from torchrl.data.replay_buffers.storages import LazyMemmapStorage -from torchrl.envs import ( - CatTensors, - DMControlEnv, - EnvCreator, - ParallelEnv, - TransformedEnv, -) -from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import MLP, SafeModule -from torchrl.modules.distributions import OneHotCategorical -from torchrl.modules.tensordict_module.actors import ProbabilisticActor - -from torchrl.objectives import DiscreteSACLoss, SoftUpdate from torchrl.record.loggers import generate_exp_name, get_logger - - -def env_maker(cfg, device="cpu"): - lib = cfg.env.library - if lib in ("gym", "gymnasium"): - with set_gym_backend(lib): - return GymEnv( - cfg.env.name, - device=device, - ) - elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) - return TransformedEnv( - env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") - ) - else: - raise NotImplementedError(f"Unknown lib {lib}.") - - -def make_replay_buffer( - prb=False, - buffer_size=1000000, - batch_size=256, - buffer_scratch_dir=None, - device="cpu", - prefetch=3, -): - if prb: - replay_buffer = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - pin_memory=False, - batch_size=batch_size, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - ) - else: - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - batch_size=batch_size, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=buffer_scratch_dir, - device=device, - ), - ) - return replay_buffer +from utils import ( + log_metrics, + make_collector, + make_environment, + make_loss_module, + make_optimizer, + make_replay_buffer, + make_sac_agent, +) @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - - device = ( - torch.device("cuda:0") - if torch.cuda.is_available() - and torch.cuda.device_count() > 0 - and cfg.device == "cuda:0" - else torch.device("cpu") - ) - - exp_name = generate_exp_name("Discrete_SAC", cfg.exp_name) - logger = get_logger( - logger_type=cfg.logger, - logger_name="dSAC_logging", - experiment_name=exp_name, - wandb_kwargs={"mode": cfg.mode}, - ) - - torch.manual_seed(cfg.seed) - np.random.seed(cfg.seed) - - def env_factory(num_workers): - """Creates an instance of the environment.""" - - # 1.2 Create env vector - vec_env = ParallelEnv( - create_env_fn=EnvCreator(lambda cfg=cfg: env_maker(cfg)), - num_workers=num_workers, + device = torch.device(cfg.network.device) + + # Create logger + exp_name = generate_exp_name("DiscreteSAC", cfg.env.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="DiscreteSAC_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, ) - return vec_env + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) - # Sanity check - test_env = env_factory(num_workers=5) - num_actions = test_env.action_spec.space.n + # Create environments + train_env, eval_env = make_environment(cfg) - # Create Agent - # Define Actor Network - in_keys = ["observation"] + # Create agent + model = make_sac_agent(cfg, train_env, eval_env, device) - actor_net_kwargs = { - "num_cells": [256, 256], - "out_features": num_actions, - "activation_class": nn.ReLU, - } + # Create TD3 loss + loss_module, target_net_updater = make_loss_module(cfg, model) - actor_net = MLP(**actor_net_kwargs) - - actor_module = SafeModule( - module=actor_net, - in_keys=in_keys, - out_keys=["logits"], - ) - actor = ProbabilisticActor( - spec=CompositeSpec(action=test_env.action_spec), - module=actor_module, - in_keys=["logits"], - out_keys=["action"], - distribution_class=OneHotCategorical, - distribution_kwargs={}, - default_interaction_type=InteractionType.RANDOM, - return_log_prob=False, - ).to(device) - - # Define Critic Network - qvalue_net_kwargs = { - "num_cells": [256, 256], - "out_features": num_actions, - "activation_class": nn.ReLU, - } - - qvalue_net = MLP( - **qvalue_net_kwargs, - ) + # Create off-policy collector + collector = make_collector(cfg, train_env, model[0]) - qvalue = TensorDictModule( - in_keys=in_keys, - out_keys=["action_value"], - module=qvalue_net, - ).to(device) - - # init nets - with torch.no_grad(): - td = test_env.reset() - td = td.to(device) - actor(td) - qvalue(td) - - del td - test_env.close() - test_env.eval() - - model = torch.nn.ModuleList([actor, qvalue]) - - # Create SAC loss - loss_module = DiscreteSACLoss( - actor_network=model[0], - action_space=test_env.action_spec, - qvalue_network=model[1], - num_actions=num_actions, - num_qvalue_nets=2, - target_entropy_weight=cfg.target_entropy_weight, - loss_function="smooth_l1", - ) - loss_module.make_value_estimator(gamma=cfg.gamma) - - # Define Target Network Updater - target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) - - # Make Off-Policy Collector - collector = SyncDataCollector( - env_factory, - create_env_kwargs={"num_workers": cfg.env_per_collector}, - policy=model[0], - frames_per_batch=cfg.frames_per_batch, - max_frames_per_traj=cfg.max_frames_per_traj, - total_frames=cfg.total_frames, - device=cfg.device, - ) - collector.set_seed(cfg.seed) - - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - prb=cfg.prb, - buffer_size=cfg.buffer_size, - batch_size=cfg.batch_size, + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir=cfg.replay_buffer.scratch_dir, device="cpu", ) - # Optimizers - params = list(loss_module.parameters()) - optimizer_actor = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay) - - rewards = [] - rewards_eval = [] + # Create optimizers + optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer( + cfg, loss_module + ) # Main loop + start_time = time.time() collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.total_frames) - r0 = None - loss = None + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_rollout_steps = cfg.env.max_episode_steps + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch - for i, tensordict in enumerate(collector): + sampling_start = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start - # update weights of the inference policy + # Update weights of the inference policy collector.update_policy_weights_() - new_collected_epochs = len(np.unique(tensordict["collector"]["traj_ids"])) - if r0 is None: - r0 = ( - tensordict["next", "reward"].sum().item() - / new_collected_epochs - / cfg.env_per_collector - ) pbar.update(tensordict.numel()) - # extend the replay buffer with the new data - if "mask" in tensordict.keys(): - # if multi-step, a mask is present to help filter padded values - current_frames = tensordict["mask"].sum() - tensordict = tensordict[tensordict.get("mask").squeeze(-1)] - else: - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - total_collected_epochs = tensordict["collector"]["traj_ids"].max().item() - # optimization steps - if collected_frames >= cfg.init_random_frames: + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: ( - total_losses, actor_losses, q_losses, alpha_losses, - alphas, - entropies, - ) = ([], [], [], [], [], []) - for _ in range(cfg.frames_per_batch * int(cfg.utd_ratio)): - # sample from replay buffer + ) = ([], [], []) + for _ in range(num_updates): + # Sample from replay buffer sampled_tensordict = replay_buffer.sample() if sampled_tensordict.device != device: sampled_tensordict = sampled_tensordict.to( @@ -280,80 +127,88 @@ def env_factory(num_workers): else: sampled_tensordict = sampled_tensordict.clone() - loss_td = loss_module(sampled_tensordict) + # Compute loss + loss_out = loss_module(sampled_tensordict) + + actor_loss, q_loss, alpha_loss = ( + loss_out["loss_actor"], + loss_out["loss_qvalue"], + loss_out["loss_alpha"], + ) - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - alpha_loss = loss_td["loss_alpha"] + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_losses.append(q_loss.item()) - loss = actor_loss + q_loss + alpha_loss + # Update actor optimizer_actor.zero_grad() - loss.backward() + actor_loss.backward() optimizer_actor.step() - # update qnet_target params - target_net_updater.step() + actor_losses.append(actor_loss.item()) - # update priority - if cfg.prb: - replay_buffer.update_priority(sampled_tensordict) + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() - total_losses.append(loss.item()) - actor_losses.append(actor_loss.item()) - q_losses.append(q_loss.item()) alpha_losses.append(alpha_loss.item()) - alphas.append(loss_td["alpha"].item()) - entropies.append(loss_td["entropy"].item()) - rewards.append( - ( - i, - tensordict["next", "reward"].sum().item() - / cfg.env_per_collector - / new_collected_epochs, - ) - ) - metrics = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - "epochs": total_collected_epochs, - } - - if loss is not None: - metrics.update( - { - "total_loss": np.mean(total_losses), - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "alpha_loss": np.mean(alpha_losses), - "alpha": np.mean(alphas), - "entropy": np.mean(entropies), - } - ) + # Update target params + target_net_updater.step() - with set_exploration_type( - ExplorationType.RANDOM - ), torch.no_grad(): # TODO: exploration mode to mean causes nans - - eval_rollout = test_env.rollout( - max_steps=cfg.max_frames_per_traj, - policy=actor, - auto_cast_to_device=True, - ).clone() - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - metrics.update({"test_reward": rewards_eval[-1][1]}) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - # log metrics - for key, value in metrics.items(): - logger.log_scalar(key, value, step=collected_frames) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/discrete_sac/utils.py b/examples/discrete_sac/utils.py new file mode 100644 index 00000000000..f7d581ce7e2 --- /dev/null +++ b/examples/discrete_sac/utils.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import tempfile +from contextlib import nullcontext + +import torch +from tensordict.nn import InteractionType, TensorDictModule + +from torch import nn, optim +from torchrl.collectors import SyncDataCollector +from torchrl.data import ( + CompositeSpec, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + InitTracker, + ParallelEnv, + RewardSum, + StepCounter, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import MLP, SafeModule +from torchrl.modules.distributions import OneHotCategorical + +from torchrl.modules.tensordict_module.actors import ProbabilisticActor +from torchrl.objectives import SoftUpdate +from torchrl.objectives.sac import DiscreteSACLoss + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps): + transformed_env = TransformedEnv( + env, + Compose( + StepCounter(max_steps=max_episode_steps), + InitTracker(), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms( + parallel_env, max_episode_steps=cfg.env.max_episode_steps + ) + + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + reset_at_each_iter=cfg.collector.reset_at_each_iter, + device=cfg.collector.device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1000000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, +): + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return replay_buffer + + +# ==================================================================== +# Model +# ----- + + +def make_sac_agent(cfg, train_env, eval_env, device): + """Make discrete SAC agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + # Define Actor Network + in_keys = ["observation"] + + actor_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + + actor_net = MLP(**actor_net_kwargs) + + actor_module = SafeModule( + module=actor_net, + in_keys=in_keys, + out_keys=["logits"], + ) + actor = ProbabilisticActor( + spec=CompositeSpec(action=eval_env.action_spec), + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + distribution_kwargs={}, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = TensorDictModule( + in_keys=in_keys, + out_keys=["action_value"], + module=qvalue_net, + ) + + model = torch.nn.ModuleList([actor, qvalue]).to(device) + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + eval_env.close() + + return model + + +# ==================================================================== +# Discrete SAC Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create discrete SAC loss + loss_module = DiscreteSACLoss( + actor_network=model[0], + qvalue_network=model[1], + num_actions=model[0].spec["action"].space.n, + num_qvalue_nets=2, + loss_function=cfg.optim.loss_function, + target_entropy_weight=cfg.optim.target_entropy_weight, + delay_qvalue=True, + ) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) + return loss_module, target_net_updater + + +def make_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + ) + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=3.0e-4, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 618e8b9a554..4da874148e7 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -785,7 +785,7 @@ class DiscreteSACLoss(LossModule): :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). num_actions (int, optional): number of actions in the action space. - To be provided if target_entropy is ste to "auto". + To be provided if target_entropy is set to "auto". num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", "l1", Default is "smooth_l1".