From 9045977b6ca6ed8089864cf43646388b36ae7f9c Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 12 Sep 2023 12:18:10 +0200 Subject: [PATCH 01/34] fix --- examples/sac/config.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 22cba615d30..1777cae2617 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -2,15 +2,15 @@ env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-SAC" + exp_name: "HalfCheetah-SAC-ICLR" library: gym frame_skip: 1 seed: 1 -# Collection +# collector collector: total_frames: 1000000 - init_random_frames: 10000 + init_random_frames: 25000 frames_per_batch: 1000 max_frames_per_traj: 1000 init_env_steps: 1000 @@ -19,13 +19,13 @@ collector: env_per_collector: 1 num_workers: 1 -# Replay Buffer +# replay buffer replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay -# Optimization -optimization: +# optim +optim: utd_ratio: 1.0 gamma: 0.99 loss_function: smooth_l1 @@ -35,7 +35,7 @@ optimization: batch_size: 256 target_update_polyak: 0.995 -# Algorithm +# network network: hidden_sizes: [256, 256] activation: relu @@ -43,7 +43,7 @@ network: scale_lb: 0.1 device: "cuda:0" -# Logging +# logging logger: backend: wandb mode: online From 7d42ba537a927f999b6d9c5a5b686545c241df8e Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 12 Sep 2023 15:54:52 +0200 Subject: [PATCH 02/34] update optimizer --- examples/sac/config.yaml | 2 +- examples/sac/sac.py | 145 ++++++++++++++++++++++++--------------- examples/sac/utils.py | 35 ++++++---- 3 files changed, 112 insertions(+), 70 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 1777cae2617..466aedf1532 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -28,7 +28,7 @@ replay_buffer: optim: utd_ratio: 1.0 gamma: 0.99 - loss_function: smooth_l1 + loss_function: l2 lr: 3e-4 weight_decay: 2e-4 lr_scheduler: "" diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 17b997cfda6..2df639dde49 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -11,13 +11,15 @@ The helper functions are coded in the utils.py associated with this script. 
""" +import time + import hydra import numpy as np import torch import torch.cuda import tqdm - +from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -35,6 +37,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create Logger exp_name = generate_exp_name("SAC", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -50,81 +53,98 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create Environments train_env, eval_env = make_environment(cfg) + # Create Agent model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) - # Create TD3 loss + # Create SAC loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create Off-Policy Collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create Replay Buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, device=device, ) - # Make Optimizers - optimizer = make_sac_optimizer(cfg, loss_module) - - rewards = [] - rewards_eval = [] + # Create Optimizers + optimizer_actor, optimizer_critic, optimizer_alpha = make_sac_optimizer( + cfg, loss_module + ) # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector eval_iter = cfg.logger.eval_iter frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + sampling_start = time.time() for i, tensordict in enumerate(collector): + sampling_time = time.time() - sampling_start + # update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) - tensordict = tensordict.view(-1) + tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames # optimization steps + training_start = time.time() if collected_frames >= init_random_frames: - (actor_losses, q_losses, alpha_losses) = ([], [], []) - for _ in range(num_updates): + losses = TensorDict( + {}, + batch_size=[ + num_updates, + ], + ) + for i in range(num_updates): # sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() + # compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] - loss = actor_loss + q_loss + alpha_loss - optimizer.zero_grad() - loss.backward() - optimizer.step() + # update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - alpha_losses.append(alpha_loss.item()) + # update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + losses[i] = 
loss_td.select( + "loss_actor", "loss_qvalue", "loss_alpha" + ).detach() # update qnet_target params target_net_updater.step() @@ -132,48 +152,61 @@ def main(cfg: "DictConfig"): # noqa: F821 # update priority if prb: replay_buffer.update_priority(sampled_tensordict) - - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) - ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "alpha_loss": np.mean(alpha_losses), - "alpha": loss_td["alpha"], - "entropy": loss_td["entropy"], - } + training_time = time.time() - training_start + episode_rewards = tensordict["next", "episode_reward"][ + tensordict["next", "done"] + ] + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + logger.log_scalar( + "train/reward", episode_rewards.mean().item(), collected_frames + ) + logger.log_scalar( + "train/episode_length", + episode_length.sum().item() / len(episode_length), + collected_frames, + ) + if collected_frames >= init_random_frames: + logger.log_scalar( + "train/q_loss", losses.get("loss_qvalue").mean(), step=collected_frames + ) + logger.log_scalar( + "train/a_loss", losses.get("loss_actor").mean(), step=collected_frames ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) + logger.log_scalar( + "train/alpha_loss", + losses.get("loss_alpha").mean(), + step=collected_frames, + ) + logger.log_scalar("train/alpha", loss_td["alpha"], step=collected_frames) + logger.log_scalar( + "train/entropy", loss_td["entropy"], step=collected_frames + ) + logger.log_scalar("train/sampling_time", sampling_time, collected_frames) + logger.log_scalar("train/training_time", training_time, collected_frames) + + # evaluation + if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + logger.log_scalar("eval/reward", eval_reward, step=collected_frames) + logger.log_scalar("eval/time", eval_time, step=collected_frames) collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 9c6f71ffa6c..cab7a72c6e6 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -7,7 +7,7 @@ from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling +from torchrl.envs.transforms import RewardScaling, RewardSum from torchrl.envs.utils import ExplorationType, 
set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal @@ -29,7 +29,8 @@ def apply_env_transforms(env, reward_scaling=1.0): env, Compose( RewardScaling(loc=0.0, scale=reward_scaling), - DoubleToFloat(), + DoubleToFloat("observation"), + RewardSum(), ), ) return transformed_env @@ -65,6 +66,7 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, @@ -214,24 +216,31 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=False, delay_qvalue=True, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater def make_sac_optimizer(cfg, loss_module): - """Make SAC optimizer.""" - optimizer = optim.Adam( - loss_module.parameters(), - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + ) + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=cfg.optim.lr, ) - return optimizer + return optimizer_actor, optimizer_critic, optimizer_alpha From 738c6dffb8c83656cd340599d05e4c49da244532 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 13 Sep 2023 13:05:48 +0200 Subject: [PATCH 03/34] fix --- examples/sac/config.yaml | 7 +++---- examples/sac/sac.py | 1 - torchrl/objectives/sac.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 466aedf1532..1045dc4889a 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -9,9 +9,9 @@ env: # collector collector: - total_frames: 1000000 + total_frames: 3_000_000 init_random_frames: 25000 - frames_per_batch: 1000 + frames_per_batch: 1 max_frames_per_traj: 1000 init_env_steps: 1000 async_collection: 1 @@ -30,8 +30,7 @@ optim: gamma: 0.99 loss_function: l2 lr: 3e-4 - weight_decay: 2e-4 - lr_scheduler: "" + weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 2df639dde49..fddf6e9e2c9 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -188,7 +188,6 @@ def main(cfg: "DictConfig"): # noqa: F821 logger.log_scalar("train/training_time", training_time, collected_frames) # evaluation - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index de4908d1335..591dd831044 100644 --- a/torchrl/objectives/sac.py +++ 
b/torchrl/objectives/sac.py @@ -701,7 +701,7 @@ def _qvalue_v2_loss( pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).mean(0) + ).sum(0) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata @@ -749,7 +749,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + alpha_loss = -self.log_alpha.exp() * (log_prob + self.target_entropy) else: # placeholder alpha_loss = torch.zeros_like(log_prob) From e3e4ced787b060b9916c0d6f632931e7bc17fd1f Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 09:52:26 +0200 Subject: [PATCH 04/34] add init alpha option --- examples/sac/config.yaml | 1 + examples/sac/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 1045dc4889a..6f2918cebcd 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -33,6 +33,7 @@ optim: weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 + alpha_init: 0.2 # network network: diff --git a/examples/sac/utils.py b/examples/sac/utils.py index cab7a72c6e6..7bf6ff43638 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -219,6 +219,7 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, delay_actor=False, delay_qvalue=True, + alpha_init=cfg.optim.alpha_init, ) loss_module.make_value_estimator(gamma=cfg.optim.gamma) From 28973405f8100a46ee1b9c8853c5b46d076dc447 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 10:02:00 +0200 Subject: [PATCH 05/34] logalpha fix --- torchrl/objectives/sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 591dd831044..ab39837a459 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -749,7 +749,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_prob + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder alpha_loss = torch.zeros_like(log_prob) From f83c6e6ee2fb300abb84c85c020c2301e3e39cdb Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 15:18:53 +0200 Subject: [PATCH 06/34] naming fixes --- examples/sac/sac.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/sac/sac.py b/examples/sac/sac.py index fddf6e9e2c9..fb6c1b1cbab 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -37,7 +37,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) - # Create Logger + # Create logger exp_name = generate_exp_name("SAC", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -51,19 +51,19 @@ def main(cfg: "DictConfig"): # noqa: F821 torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + # Create agent model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) # Create SAC loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Create Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Create Replay Buffer + # Create replay 
buffer replay_buffer = make_replay_buffer( batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, @@ -71,7 +71,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device=device, ) - # Create Optimizers + # Create optimizers optimizer_actor, optimizer_critic, optimizer_alpha = make_sac_optimizer( cfg, loss_module ) @@ -96,18 +96,18 @@ def main(cfg: "DictConfig"): # noqa: F821 for i, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start - # update weights of the inference policy + # Update weights of the inference policy collector.update_policy_weights_() pbar.update(tensordict.numel()) tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # add to replay buffer + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps training_start = time.time() if collected_frames >= init_random_frames: losses = TensorDict( @@ -117,27 +117,27 @@ def main(cfg: "DictConfig"): # noqa: F821 ], ) for i in range(num_updates): - # sample from replay buffer + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() - # compute loss + # Compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] - # update actor + # Update actor optimizer_actor.zero_grad() actor_loss.backward() optimizer_actor.step() - # update critic + # Update critic optimizer_critic.zero_grad() q_loss.backward() optimizer_critic.step() - # update alpha + # Update alpha optimizer_alpha.zero_grad() alpha_loss.backward() optimizer_alpha.step() @@ -146,16 +146,19 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_actor", "loss_qvalue", "loss_alpha" ).detach() - # update qnet_target params + # Update qnet_target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) + training_time = time.time() - training_start episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] + + # Logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][ tensordict["next", "done"] @@ -187,7 +190,7 @@ def main(cfg: "DictConfig"): # noqa: F821 logger.log_scalar("train/sampling_time", sampling_time, collected_frames) logger.log_scalar("train/training_time", training_time, collected_frames) - # evaluation + # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() From e58b9b0c783aeb42b319e4e18bf3511a9e5d0ef6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 14 Sep 2023 15:19:41 +0200 Subject: [PATCH 07/34] fix --- examples/sac/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 6f2918cebcd..315cbe6923b 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,4 +1,4 @@ -# Environment +# environment and task env: name: HalfCheetah-v3 task: "" From 1956d806505f35b82b0c236904c7d24d31b856a1 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 15 Sep 2023 09:36:59 +0200 Subject: [PATCH 08/34] update logging small fixes --- examples/sac/config.yaml | 4 ++-- examples/sac/sac.py | 43 ++++++++++++++++------------------------ examples/sac/utils.py | 32 ++++++++++++++++++++---------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml 
b/examples/sac/config.yaml
index 315cbe6923b..ec55e26de28 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -2,7 +2,7 @@ env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-SAC-ICLR" + exp_name: "HalfCheetah-SAC" library: gym frame_skip: 1 seed: 1 @@ -11,7 +11,7 @@ env: collector: total_frames: 3_000_000 init_random_frames: 25000 - frames_per_batch: 1 + frames_per_batch: 1000 max_frames_per_traj: 1000 init_env_steps: 1000 async_collection: 1 diff --git a/examples/sac/sac.py b/examples/sac/sac.py index fb6c1b1cbab..8a9d040e5d1 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_environment, make_loss_module, @@ -159,36 +160,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ] # Logging + metrics_to_log = {} if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][ tensordict["next", "done"] ] - logger.log_scalar( - "train/reward", episode_rewards.mean().item(), collected_frames - ) - logger.log_scalar( - "train/episode_length", - episode_length.sum().item() / len(episode_length), - collected_frames, + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) if collected_frames >= init_random_frames: - logger.log_scalar( - "train/q_loss", losses.get("loss_qvalue").mean(), step=collected_frames - ) - logger.log_scalar( - "train/a_loss", losses.get("loss_actor").mean(), step=collected_frames - ) - logger.log_scalar( - "train/alpha_loss", - losses.get("loss_alpha").mean(), - step=collected_frames, - ) - logger.log_scalar("train/alpha", loss_td["alpha"], step=collected_frames) - logger.log_scalar( - "train/entropy", loss_td["entropy"], step=collected_frames - ) - logger.log_scalar("train/sampling_time", sampling_time, collected_frames) - logger.log_scalar("train/training_time", training_time, collected_frames) + metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item() + metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() + metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() + metrics_to_log["train/alpha"] = loss_td["alpha"].item() + metrics_to_log["train/entropy"] = loss_td["entropy"].item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: @@ -202,8 +190,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - logger.log_scalar("eval/reward", eval_reward, step=collected_frames) - logger.log_scalar("eval/time", eval_time, step=collected_frames) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() end_time = time.time() diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 7bf6ff43638..d8bbe0e0b5f 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -116,17 +116,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def 
make_sac_agent(cfg, train_env, eval_env, device): """Make SAC agent.""" # Define Actor Network @@ -245,3 +234,24 @@ def make_sac_optimizer(cfg, loss_module): lr=cfg.optim.lr, ) return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError From 8a6030132ba5fb4eef78e1c41a41883f23e3ef1e Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 15 Sep 2023 13:01:13 +0200 Subject: [PATCH 09/34] add wd --- examples/sac/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index ec55e26de28..aac93657460 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -30,7 +30,7 @@ optim: gamma: 0.99 loss_function: l2 lr: 3e-4 - weight_decay: 0.0 + weight_decay: 1.0e-4 batch_size: 256 target_update_polyak: 0.995 alpha_init: 0.2 From e56b46bda5674f36426c086105ad8a761cb04324 Mon Sep 17 00:00:00 2001 From: BY571 Date: Sun, 17 Sep 2023 14:34:18 +0200 Subject: [PATCH 10/34] add eps --- examples/sac/config.yaml | 5 +++-- examples/sac/utils.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index aac93657460..a0a103580b3 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -30,10 +30,11 @@ optim: gamma: 0.99 loss_function: l2 lr: 3e-4 - weight_decay: 1.0e-4 + weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 - alpha_init: 0.2 + alpha_init: 1.0 + adam_eps: 1.0e-4 # network network: diff --git a/examples/sac/utils.py b/examples/sac/utils.py index d8bbe0e0b5f..10c20934b08 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -222,12 +222,16 @@ def make_sac_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) optimizer_critic = optim.Adam( critic_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) optimizer_alpha = optim.Adam( [loss_module.log_alpha], From 220861afd1935290e86607ed3dcaade70ec04bc4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 18 Sep 2023 10:52:35 +0200 Subject: [PATCH 11/34] no eps --- examples/sac/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index a0a103580b3..fc2c3fd9369 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -34,7 +34,7 @@ optim: batch_size: 256 target_update_polyak: 0.995 alpha_init: 1.0 - adam_eps: 1.0e-4 + adam_eps: 1.0e-8 # network network: From 0974772d18ac3f457b63b0ec085e431ca6208e40 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 18 Sep 2023 11:56:50 +0200 Subject: [PATCH 12/34] undetach q at actorloss --- torchrl/objectives/sac.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 
ab39837a459..47646626f8c 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -575,7 +575,8 @@ def _actor_loss( td_q = tensordict.select(*self.qvalue_network.in_keys) td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( - td_q, self._cached_detached_qvalue_params # should we clone? + td_q, + self.qvalue_network_params, # _cached_detached_qvalue_params # should we clone? ) min_q_logprob = ( td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -697,11 +698,14 @@ def _qvalue_v2_loss( -1 ) td_error = abs(pred_val - target_value) - loss_qval = distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, - ).sum(0) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).sum(0) + * 0.5 + ) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From ac54930e70d2447b91f0b7df923120b2cab5fde3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 19 Sep 2023 10:30:55 +0200 Subject: [PATCH 13/34] tests --- examples/sac/config.yaml | 6 +++--- torchrl/objectives/sac.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index fc2c3fd9369..d1a8b6eaf63 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -29,11 +29,11 @@ optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 3e-4 + lr: 1.0e-3 weight_decay: 0.0 - batch_size: 256 + batch_size: 100 target_update_polyak: 0.995 - alpha_init: 1.0 + alpha_init: 0.2 adam_eps: 1.0e-8 # network diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 47646626f8c..cc24844ad67 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -576,7 +576,7 @@ def _actor_loss( td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( td_q, - self.qvalue_network_params, # _cached_detached_qvalue_params # should we clone? + self._cached_detached_qvalue_params, # should we clone? 
) min_q_logprob = ( td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -698,14 +698,11 @@ def _qvalue_v2_loss( -1 ) td_error = abs(pred_val - target_value) - loss_qval = ( - distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, - ).sum(0) - * 0.5 - ) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).mean(0) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From 1cfc8219f2d2fd2ada252b5a44ca1c18abe8e9ba Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 19 Sep 2023 10:49:06 +0200 Subject: [PATCH 14/34] update test --- .github/unittest/linux_examples/scripts/run_test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index d81e90fdd42..d36af4d2cce 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -239,8 +239,8 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.num_workers=2 \ collector.env_per_collector=1 \ collector.collector_device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= From 500bd5d95370ebc3c1a0d9c0019a180a0f23ea05 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 19 Sep 2023 11:29:15 +0200 Subject: [PATCH 15/34] update test --- .github/unittest/linux_examples/scripts/run_test.sh | 4 ++-- examples/sac/config.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index d36af4d2cce..8d8bd007d1d 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -115,8 +115,8 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.num_workers=4 \ collector.env_per_collector=2 \ collector.collector_device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index d1a8b6eaf63..b82d1861893 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -29,9 +29,9 @@ optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 1.0e-3 + lr: 3.0e-4 weight_decay: 0.0 - batch_size: 100 + batch_size: 256 target_update_polyak: 0.995 alpha_init: 0.2 adam_eps: 1.0e-8 From 4b234463d40583ac2830d0d5184484b0d242f5e3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Sep 2023 11:38:55 +0200 Subject: [PATCH 16/34] update config, test add set_gym_backend --- .github/unittest/linux_examples/scripts/run_test.sh | 2 -- examples/sac/config.yaml | 2 -- examples/sac/utils.py | 7 +++++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 8d8bd007d1d..ef52710b751 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -112,7 +112,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ 
collector.frames_per_batch=16 \ - collector.num_workers=4 \ collector.env_per_collector=2 \ collector.collector_device=cuda:0 \ optim.batch_size=10 \ @@ -236,7 +235,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ - collector.num_workers=2 \ collector.env_per_collector=1 \ collector.collector_device=cuda:0 \ optim.batch_size=10 \ diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index b82d1861893..9576180e61a 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -14,10 +14,8 @@ collector: frames_per_batch: 1000 max_frames_per_traj: 1000 init_env_steps: 1000 - async_collection: 1 collector_device: cpu env_per_collector: 1 - num_workers: 1 # replay buffer replay_buffer: diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 10c20934b08..db5ea365a70 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -6,7 +6,7 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv -from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.transforms import RewardScaling, RewardSum from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator @@ -21,7 +21,10 @@ def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) + with set_gym_backend("gym"): + return GymEnv( + task, device=device, frame_skip=frame_skip, from_pixels=from_pixels + ) def apply_env_transforms(env, reward_scaling=1.0): From 567cd2be49a9b3e57503138309daad680a6d398d Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Sep 2023 11:42:52 +0200 Subject: [PATCH 17/34] update header --- examples/sac/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/sac/utils.py b/examples/sac/utils.py index db5ea365a70..9cdb94521bb 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ import torch from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor From ede6064f6c1a654f399152539231358469a69365 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 22 Sep 2023 14:43:41 +0200 Subject: [PATCH 18/34] fix max episode steps --- examples/sac/config.yaml | 1 + examples/sac/sac.py | 11 ++++++----- examples/sac/utils.py | 22 ++++++++++++++++++---- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 9576180e61a..65fb1151bc5 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -5,6 +5,7 @@ env: exp_name: "HalfCheetah-SAC" library: gym frame_skip: 1 + max_episode_steps: 1_000_000 seed: 1 # collector diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 8a9d040e5d1..c2deebc2c25 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -155,16 +155,17 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer.update_priority(sampled_tensordict) training_time = time.time() - training_start - episode_rewards = tensordict["next", "episode_reward"][ + episode_end = ( tensordict["next", "done"] - ] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] # Logging metrics_to_log = {} if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] + episode_length = tensordict["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( episode_length diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 9cdb94521bb..ea8ece21682 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -25,10 +25,16 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): +def env_maker( + task, frame_skip=1, device="cpu", from_pixels=False, max_episode_steps=1000 +): with set_gym_backend("gym"): return GymEnv( - task, device=device, frame_skip=frame_skip, from_pixels=from_pixels + task, + device=device, + frame_skip=frame_skip, + from_pixels=from_pixels, + max_episode_steps=max_episode_steps, ) @@ -48,7 +54,11 @@ def make_environment(cfg): """Make environments for training and evaluation.""" parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator( + lambda: env_maker( + task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps + ) + ), ) parallel_env.set_seed(cfg.env.seed) @@ -57,7 +67,11 @@ def make_environment(cfg): eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda: env_maker(task=cfg.env.name)), + EnvCreator( + lambda: env_maker( + task=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps + ) + ), ), train_env.transform.clone(), ) From b2a04e6d3c49f4cc708751f7682ace34c8048d90 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 26 Sep 2023 08:31:11 +0200 Subject: [PATCH 19/34] update objective --- torchrl/objectives/sac.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index cc24844ad67..750d568e992 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -698,11 +698,15 @@ def _qvalue_v2_loss( -1 ) td_error = abs(pred_val - target_value) - loss_qval = distance_loss( - pred_val, - target_value.expand_as(pred_val), - 
loss_function=self.loss_function, - ).mean(0) + loss_qval = ( + distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ) + .mean(-1) + .sum(0) + ) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From 1bf7382787432132147d80f5ea6a46ad122a2d6d Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 26 Sep 2023 08:35:46 +0200 Subject: [PATCH 20/34] update objective --- torchrl/objectives/sac.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 750d568e992..983a6b8c46f 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -698,15 +698,11 @@ def _qvalue_v2_loss( -1 ) td_error = abs(pred_val - target_value) - loss_qval = ( - distance_loss( - pred_val, - target_value.expand_as(pred_val), - loss_function=self.loss_function, - ) - .mean(-1) - .sum(0) - ) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).sum(0) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From 01d6e560f6b588a3ca6f75dd085e6593853fe079 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 26 Sep 2023 22:06:59 +0200 Subject: [PATCH 21/34] sep critic opti --- examples/sac/config.yaml | 2 +- examples/sac/sac.py | 27 ++++++++++++++++++--------- examples/sac/utils.py | 24 +++++++++++++++++++++--- torchrl/objectives/sac.py | 7 ++++--- 4 files changed, 44 insertions(+), 16 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 65fb1151bc5..b5b1828e9f2 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -28,7 +28,7 @@ optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 3.0e-4 + lr: 1.0e-3 weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 diff --git a/examples/sac/sac.py b/examples/sac/sac.py index c2deebc2c25..8d50c4ac55e 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -73,9 +73,12 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - optimizer_actor, optimizer_critic, optimizer_alpha = make_sac_optimizer( - cfg, loss_module - ) + ( + optimizer_actor, + optimizer_critic1, + optimizer_critic2, + optimizer_alpha, + ) = make_sac_optimizer(cfg, loss_module) # Main loop start_time = time.time() @@ -125,7 +128,8 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] + q1_loss = loss_td["loss_qvalue1"] + q2_loss = loss_td["loss_qvalue2"] alpha_loss = loss_td["loss_alpha"] # Update actor @@ -134,9 +138,13 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor.step() # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() + optimizer_critic1.zero_grad() + q1_loss.backward(retain_graph=True) + optimizer_critic1.step() + + optimizer_critic2.zero_grad() + q2_loss.backward() + optimizer_critic2.step() # Update alpha optimizer_alpha.zero_grad() @@ -144,7 +152,7 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_alpha.step() losses[i] = loss_td.select( - "loss_actor", "loss_qvalue", "loss_alpha" + "loss_actor", "loss_qvalue1", "loss_qvalue2", "loss_alpha" ).detach() # Update qnet_target params @@ -171,7 +179,8 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item() + metrics_to_log["train/q1_loss"] = 
losses.get("loss_qvalue1").mean().item() + metrics_to_log["train/q2_loss"] = losses.get("loss_qvalue2").mean().item() metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() metrics_to_log["train/alpha"] = loss_td["alpha"].item() diff --git a/examples/sac/utils.py b/examples/sac/utils.py index ea8ece21682..57f046ac3f1 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -239,8 +239,20 @@ def make_loss_module(cfg, model): return loss_module, target_net_updater +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + def make_sac_optimizer(cfg, loss_module): critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + critic1_params, critic2_params = split_critic_params(critic_params) actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( @@ -249,8 +261,14 @@ def make_sac_optimizer(cfg, loss_module): weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, ) - optimizer_critic = optim.Adam( - critic_params, + optimizer_critic1 = optim.Adam( + critic1_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic2 = optim.Adam( + critic2_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, @@ -259,7 +277,7 @@ def make_sac_optimizer(cfg, loss_module): [loss_module.log_alpha], lr=cfg.optim.lr, ) - return optimizer_actor, optimizer_critic, optimizer_alpha + return optimizer_actor, optimizer_critic1, optimizer_critic2, optimizer_alpha # ==================================================================== diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 983a6b8c46f..e3619712a49 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -536,7 +536,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) - if (loss_actor.shape != loss_qvalue.shape) or ( + if (loss_actor.shape != loss_qvalue.sum(0).shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): raise RuntimeError( @@ -547,7 +547,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = -metadata_actor["log_prob"].mean() out = { "loss_actor": loss_actor.mean(), - "loss_qvalue": loss_qvalue.mean(), + "loss_qvalue1": loss_qvalue[0].mean(), + "loss_qvalue2": loss_qvalue[1].mean(), "loss_alpha": loss_alpha.mean(), "alpha": self._alpha, "entropy": entropy, @@ -702,7 +703,7 @@ def _qvalue_v2_loss( pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).sum(0) + ) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From 522d061d2a348fe5cdf541d8d612195bfc007112 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Sep 2023 08:29:48 +0200 Subject: [PATCH 22/34] fixes --- examples/sac/config.yaml | 9 +++++---- examples/sac/sac.py | 21 +++++++-------------- torchrl/objectives/sac.py | 7 +++---- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/examples/sac/config.yaml 
b/examples/sac/config.yaml index b5b1828e9f2..11d3b014478 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,22 +1,23 @@ # environment and task env: - name: HalfCheetah-v3 + name: Hopper-v3 task: "" - exp_name: "HalfCheetah-SAC" + exp_name: "Hopper-SAC" library: gym frame_skip: 1 - max_episode_steps: 1_000_000 + max_episode_steps: 1001 seed: 1 # collector collector: - total_frames: 3_000_000 + total_frames: 1_000_000 init_random_frames: 25000 frames_per_batch: 1000 max_frames_per_traj: 1000 init_env_steps: 1000 collector_device: cpu env_per_collector: 1 + reset_at_each_iter: true # replay buffer replay_buffer: diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 8d50c4ac55e..6e473f6e559 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -75,8 +75,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers ( optimizer_actor, - optimizer_critic1, - optimizer_critic2, + optimizer_critic, optimizer_alpha, ) = make_sac_optimizer(cfg, loss_module) @@ -128,8 +127,7 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] - q1_loss = loss_td["loss_qvalue1"] - q2_loss = loss_td["loss_qvalue2"] + q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] # Update actor @@ -138,13 +136,9 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor.step() # Update critic - optimizer_critic1.zero_grad() - q1_loss.backward(retain_graph=True) - optimizer_critic1.step() - - optimizer_critic2.zero_grad() - q2_loss.backward() - optimizer_critic2.step() + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() # Update alpha optimizer_alpha.zero_grad() @@ -152,7 +146,7 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_alpha.step() losses[i] = loss_td.select( - "loss_actor", "loss_qvalue1", "loss_qvalue2", "loss_alpha" + "loss_actor", "loss_qvalue", "loss_alpha" ).detach() # Update qnet_target params @@ -179,8 +173,7 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q1_loss"] = losses.get("loss_qvalue1").mean().item() - metrics_to_log["train/q2_loss"] = losses.get("loss_qvalue2").mean().item() + metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item() metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() metrics_to_log["train/alpha"] = loss_td["alpha"].item() diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index e3619712a49..983a6b8c46f 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -536,7 +536,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) - if (loss_actor.shape != loss_qvalue.sum(0).shape) or ( + if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape ): raise RuntimeError( @@ -547,8 +547,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = -metadata_actor["log_prob"].mean() out = { "loss_actor": loss_actor.mean(), - "loss_qvalue1": loss_qvalue[0].mean(), - "loss_qvalue2": loss_qvalue[1].mean(), + "loss_qvalue": loss_qvalue.mean(), "loss_alpha": loss_alpha.mean(), "alpha": self._alpha, "entropy": entropy, @@ -703,7 
+702,7 @@ def _qvalue_v2_loss( pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ) + ).sum(0) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata From 67e47b6b23090bae4150b85e7e4a974be0b2896e Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Sep 2023 08:47:52 +0200 Subject: [PATCH 23/34] fix --- examples/sac/config.yaml | 4 ++-- examples/sac/utils.py | 15 ++++----------- torchrl/objectives/sac.py | 2 +- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 11d3b014478..02fb8826ce0 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -5,7 +5,7 @@ env: exp_name: "Hopper-SAC" library: gym frame_skip: 1 - max_episode_steps: 1001 + max_episode_steps: 1000 seed: 1 # collector @@ -17,7 +17,7 @@ collector: init_env_steps: 1000 collector_device: cpu env_per_collector: 1 - reset_at_each_iter: true + reset_at_each_iter: False # replay buffer replay_buffer: diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 57f046ac3f1..2a30a8fbc50 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -252,7 +252,6 @@ def split_critic_params(critic_params): def make_sac_optimizer(cfg, loss_module): critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) - critic1_params, critic2_params = split_critic_params(critic_params) actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( @@ -261,23 +260,17 @@ def make_sac_optimizer(cfg, loss_module): weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, ) - optimizer_critic1 = optim.Adam( - critic1_params, - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.adam_eps, - ) - optimizer_critic2 = optim.Adam( - critic2_params, + optimizer_critic = optim.Adam( + critic_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, ) optimizer_alpha = optim.Adam( [loss_module.log_alpha], - lr=cfg.optim.lr, + lr=3.0e-4, ) - return optimizer_actor, optimizer_critic1, optimizer_critic2, optimizer_alpha + return optimizer_actor, optimizer_critic, optimizer_alpha # ==================================================================== diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 983a6b8c46f..aa5d804dfa0 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -750,7 +750,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + alpha_loss = -self.log_alpha.exp() * (log_prob + self.target_entropy) else: # placeholder alpha_loss = torch.zeros_like(log_prob) From 5af2d9a77df164270ee4c26d4d21873c0088e122 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 27 Sep 2023 16:18:44 +0200 Subject: [PATCH 24/34] logexp test --- torchrl/objectives/sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index aa5d804dfa0..983a6b8c46f 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -750,7 +750,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_prob + self.target_entropy) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) else: # placeholder alpha_loss = 
torch.zeros_like(log_prob) From 7546aad19884a6af115d2e690be6a1035f80547e Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 28 Sep 2023 07:31:28 +0200 Subject: [PATCH 25/34] frameskip weight decay --- examples/sac/config.yaml | 14 +++++++------- examples/sac/utils.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 02fb8826ce0..361ef8b9c12 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,17 +1,17 @@ # environment and task env: - name: Hopper-v3 + name: Ant-v3 task: "" - exp_name: "Hopper-SAC" + exp_name: "Ant-SAC" library: gym frame_skip: 1 - max_episode_steps: 1000 + max_episode_steps: 5000 seed: 1 # collector collector: total_frames: 1_000_000 - init_random_frames: 25000 + init_random_frames: 10000 frames_per_batch: 1000 max_frames_per_traj: 1000 init_env_steps: 1000 @@ -29,11 +29,11 @@ optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 - lr: 1.0e-3 - weight_decay: 0.0 + lr: 3.0e-4 + weight_decay: 1.e-4 batch_size: 256 target_update_polyak: 0.995 - alpha_init: 0.2 + alpha_init: 1.0 adam_eps: 1.0e-8 # network diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 2a30a8fbc50..97e644a9390 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -32,7 +32,7 @@ def env_maker( return GymEnv( task, device=device, - frame_skip=frame_skip, + # frame_skip=frame_skip, from_pixels=from_pixels, max_episode_steps=max_episode_steps, ) From 06c2e6894c6b94cb7d650b49c6e1ab9a106855d6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 28 Sep 2023 15:48:14 +0200 Subject: [PATCH 26/34] fix frameskip, scratchdir buffer --- examples/sac/config.yaml | 2 +- examples/sac/sac.py | 1 + examples/sac/utils.py | 5 +---- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 361ef8b9c12..14e90c0acdc 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -4,7 +4,6 @@ env: task: "" exp_name: "Ant-SAC" library: gym - frame_skip: 1 max_episode_steps: 5000 seed: 1 @@ -23,6 +22,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay + scratch_dir: ant_seed3 # optim optim: diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 6e473f6e559..bd1d07302f8 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -69,6 +69,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 97e644a9390..acf7f3a9504 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -25,14 +25,11 @@ # ----------------- -def env_maker( - task, frame_skip=1, device="cpu", from_pixels=False, max_episode_steps=1000 -): +def env_maker(task, device="cpu", from_pixels=False, max_episode_steps=1000): with set_gym_backend("gym"): return GymEnv( task, device=device, - # frame_skip=frame_skip, from_pixels=from_pixels, max_episode_steps=max_episode_steps, ) From 25cf6643adbb3429cc89ccf92587d6dd5ede891f Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 28 Sep 2023 16:13:18 +0200 Subject: [PATCH 27/34] update config --- examples/sac/config.yaml | 4 ++-- examples/sac/sac.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 14e90c0acdc..4f41dcde127 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ 
-2,7 +2,7 @@ env: name: Ant-v3 task: "" - exp_name: "Ant-SAC" + exp_name: ${env.name}_SAC library: gym max_episode_steps: 5000 seed: 1 @@ -22,7 +22,7 @@ collector: replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay - scratch_dir: ant_seed3 + scratch_dir: ${env.exp_name}_${env.seed} # optim optim: diff --git a/examples/sac/sac.py b/examples/sac/sac.py index bd1d07302f8..d5540d73ad1 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -93,8 +93,8 @@ def main(cfg: "DictConfig"): # noqa: F821 ) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip - eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.collector.max_frames_per_traj sampling_start = time.time() for i, tensordict in enumerate(collector): @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( From 9a7b0b4a9b6ee474e91fe50372e344ac60dc3d54 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 2 Oct 2023 11:55:11 +0200 Subject: [PATCH 28/34] undo stepcount --- examples/sac/config.yaml | 3 +-- examples/sac/sac.py | 2 +- examples/sac/utils.py | 11 +++++------ 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 4f41dcde127..1cdfd3c7991 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,6 +1,6 @@ # environment and task env: - name: Ant-v3 + name: HalfCheetah-v3 task: "" exp_name: ${env.name}_SAC library: gym @@ -12,7 +12,6 @@ collector: total_frames: 1_000_000 init_random_frames: 10000 frames_per_batch: 1000 - max_frames_per_traj: 1000 init_env_steps: 1000 collector_device: cpu env_per_collector: 1 diff --git a/examples/sac/sac.py b/examples/sac/sac.py index d5540d73ad1..c37c52c6bea 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -94,7 +94,7 @@ def main(cfg: "DictConfig"): # noqa: F821 prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - eval_rollout_steps = cfg.collector.max_frames_per_traj + eval_rollout_steps = cfg.env.max_episode_steps sampling_start = time.time() for i, tensordict in enumerate(collector): diff --git a/examples/sac/utils.py b/examples/sac/utils.py index acf7f3a9504..e95a70449d6 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -12,7 +12,7 @@ from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.transforms import RewardScaling, RewardSum +from torchrl.envs.transforms import InitTracker, RewardSum from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal @@ -25,21 +25,20 @@ # ----------------- -def env_maker(task, device="cpu", from_pixels=False, max_episode_steps=1000): +def env_maker(task, device="cpu", max_episode_steps=1000): with set_gym_backend("gym"): return GymEnv( task, device=device, - 
from_pixels=from_pixels, max_episode_steps=max_episode_steps, ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env): transformed_env = TransformedEnv( env, Compose( - RewardScaling(loc=0.0, scale=reward_scaling), + InitTracker(), DoubleToFloat("observation"), RewardSum(), ), @@ -87,7 +86,7 @@ def make_collector(cfg, train_env, actor_model_explore): actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, + max_frames_per_traj=cfg.env.max_episode_steps, total_frames=cfg.collector.total_frames, device=cfg.collector.collector_device, ) From f4f65a5cea961ef725d65476c8753451d32fe68f Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 3 Oct 2023 14:31:43 +0200 Subject: [PATCH 29/34] fix config --- examples/sac/config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 3bc82ecb8d6..2d3425a2151 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -10,7 +10,7 @@ env: # collector collector: total_frames: 1_000_000 - init_random_frames: 10000 + init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 collector_device: cpu @@ -29,7 +29,7 @@ optim: gamma: 0.99 loss_function: l2 lr: 3.0e-4 - weight_decay: 1.e-4 + weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 alpha_init: 1.0 From b0a3799f56290311661f2704c69adf4215f58c11 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 09:14:58 -0400 Subject: [PATCH 30/34] amend --- examples/sac/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sac/utils.py b/examples/sac/utils.py index 7058b2488d0..ebbee32057b 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -39,7 +39,7 @@ def apply_env_transforms(env, max_episode_steps=1000): Compose( InitTracker(), StepCounter(max_episode_steps), - DoubleToFloat("observation"), + DoubleToFloat(), RewardSum(), ), ) From b758607b54a248508adad1e20b0c7c8566b9dcb8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 09:36:53 -0400 Subject: [PATCH 31/34] amend --- examples/sac/sac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sac/sac.py b/examples/sac/sac.py index c37c52c6bea..33b932ec42c 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -196,8 +196,8 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward metrics_to_log["eval/time"] = eval_time - - log_metrics(logger, metrics_to_log, collected_frames) + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) sampling_start = time.time() collector.shutdown() From f0482c7f6aebaeade552d676131ef35d9c8e1e38 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 15:14:19 +0100 Subject: [PATCH 32/34] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index aa85db9c69b..a5deb1283ea 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -226,17 +226,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.collector_device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ -<<<<<<< HEAD network.device=cuda:0 \ 
optimization.batch_size=10 \ optimization.utd_ratio=1 \ -======= ->>>>>>> 0272f031af1aa2ff2be400e283b79de55fcf9178 replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= -# record_video=True \ -# record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ total_frames=48 \ batch_size=10 \ From fbbc287d0fdcd921abaf9c1aada52e39bab145f6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 15:54:14 +0100 Subject: [PATCH 33/34] amend --- .../unittest/linux_examples/scripts/run_test.sh | 16 ++++++++-------- examples/ddpg/config.yaml | 2 +- examples/ddpg/ddpg.py | 4 ++-- examples/ddpg/utils.py | 12 ++++++------ examples/td3/config.yaml | 2 +- examples/td3/td3.py | 6 +++--- examples/td3/utils.py | 12 ++++++------ 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index a5deb1283ea..3229e08d891 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -66,13 +66,13 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_atari. python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ collector.collector_device=cuda:0 \ network.device=cuda:0 \ - optimization.utd_ratio=1 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= @@ -145,7 +145,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ @@ -182,13 +182,13 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ collector.collector_device=cuda:0 \ network.device=cuda:0 \ - optimization.utd_ratio=1 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= @@ -227,8 +227,8 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ optim.batch_size=10 \ optim.utd_ratio=1 \ network.device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= @@ -245,7 +245,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ - optimization.batch_size=10 \ + optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ diff --git a/examples/ddpg/config.yaml b/examples/ddpg/config.yaml index 464632f8bf3..0da2a021ed8 100644 --- a/examples/ddpg/config.yaml +++ b/examples/ddpg/config.yaml @@ -25,7 +25,7 @@ replay_buffer: prb: 0 # use 
prioritized experience replay # Optimization -optimization: +optim: utd_ratio: 1.0 gamma: 0.99 loss_function: smooth_l1 diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index b77494bc52f..2c15b93f162 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -60,7 +60,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Make Replay Buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, device=device, @@ -82,7 +82,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb env_per_collector = cfg.collector.env_per_collector diff --git a/examples/ddpg/utils.py b/examples/ddpg/utils.py index ab4083fff28..e9f9f4ca30a 100644 --- a/examples/ddpg/utils.py +++ b/examples/ddpg/utils.py @@ -217,13 +217,13 @@ def make_loss_module(cfg, model): loss_module = DDPGLoss( actor_network=model[0], value_network=model[1], - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak + loss_module, eps=cfg.optim.target_update_polyak ) return loss_module, target_net_updater @@ -233,11 +233,11 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay + actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, ) return optimizer_actor, optimizer_critic diff --git a/examples/td3/config.yaml b/examples/td3/config.yaml index 35a2d9f8b2f..994f49963b1 100644 --- a/examples/td3/config.yaml +++ b/examples/td3/config.yaml @@ -25,7 +25,7 @@ replay_buffer: size: 1000000 # Optimization -optimization: +optim: utd_ratio: 1.0 gamma: 0.99 loss_function: l2 diff --git a/examples/td3/td3.py b/examples/td3/td3.py index f4d8707f404..358fdec816b 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -62,7 +62,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Make Replay Buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, device=device, @@ -84,9 +84,9 @@ def main(cfg: "DictConfig"): # noqa: F821 num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) - delayed_updates = cfg.optimization.policy_update_delay + delayed_updates = cfg.optim.policy_update_delay prb = cfg.replay_buffer.prb env_per_collector = cfg.collector.env_per_collector eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip diff --git a/examples/td3/utils.py b/examples/td3/utils.py index 9a8c5809f75..bab4ffe1f14 100644 --- a/examples/td3/utils.py +++ b/examples/td3/utils.py @@ -222,16 +222,16 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - 
loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=True, delay_qvalue=True, action_spec=model[0][1].spec, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak + loss_module, eps=cfg.optim.target_update_polyak ) return loss_module, target_net_updater @@ -241,11 +241,11 @@ def make_optimizer(cfg, loss_module): actor_params = list(loss_module.actor_network_params.flatten_keys().values()) optimizer_actor = optim.Adam( - actor_params, lr=cfg.optimization.lr, weight_decay=cfg.optimization.weight_decay + actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, ) return optimizer_actor, optimizer_critic From b5673fb1cb246751935c2b98a43045b52b8de21d Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 3 Oct 2023 16:06:25 +0100 Subject: [PATCH 34/34] empty
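
Note on the scratch-dir replay buffer introduced above: [PATCH 26/34] passes
buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir into
make_replay_buffer, and [PATCH 27/34] sets scratch_dir to
${env.exp_name}_${env.seed}, so with the final config the memory-mapped
buffer files land roughly under /tmp/HalfCheetah-v3_SAC_1. Only the call
site appears in these patches, so the helper body below is a minimal sketch
under that assumption; apart from the keyword arguments visible in sac.py,
names and default values are illustrative rather than the exact code in
examples/sac/utils.py.

    from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.storages import LazyMemmapStorage


    def make_replay_buffer(
        batch_size,
        prb=False,
        buffer_size=1_000_000,
        buffer_scratch_dir=None,
        device="cpu",
    ):
        # Transitions are kept in memory-mapped files under scratch_dir instead
        # of holding the full 1M-frame buffer in RAM; scratch_dir=None falls
        # back to a temporary directory.
        storage = LazyMemmapStorage(
            buffer_size,
            scratch_dir=buffer_scratch_dir,
            device=device,
        )
        if prb:
            # Prioritized variant; the alpha/beta values here are illustrative.
            return TensorDictPrioritizedReplayBuffer(
                alpha=0.7,
                beta=0.5,
                storage=storage,
                batch_size=batch_size,
            )
        return TensorDictReplayBuffer(
            storage=storage,
            batch_size=batch_size,
        )

This matches the call site as it stands after [PATCH 26/34]:

    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
        device=device,
    )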