From aedcf297fadca2108885ddcb4a10210af44bf086 Mon Sep 17 00:00:00 2001
From: Albert Bou
Date: Mon, 27 Nov 2023 14:31:57 +0100
Subject: [PATCH] [BugFix] Change ppo mujoco example to match paper results (#1714)

---
 examples/a2c/a2c_mujoco.py      |  4 ++--
 examples/ppo/config_mujoco.yaml |  2 +-
 examples/ppo/ppo_mujoco.py      | 14 ++++++++------
 examples/ppo/utils_mujoco.py    |  8 +++++---
 4 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py
index 48844dee6b6..7f9e588bbf6 100644
--- a/examples/a2c/a2c_mujoco.py
+++ b/examples/a2c/a2c_mujoco.py
@@ -101,9 +101,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())
 
         # Get training rewards and lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "terminated"]]
+            episode_length = data["next", "step_count"][data["next", "done"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml
index 1272c1f4bff..0322526e7b1 100644
--- a/examples/ppo/config_mujoco.yaml
+++ b/examples/ppo/config_mujoco.yaml
@@ -18,7 +18,7 @@ logger:
 optim:
   lr: 3e-4
   weight_decay: 0.0
-  anneal_lr: False
+  anneal_lr: True
 
 # loss
 loss:
diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py
index 988bc5300bf..52b12f688e1 100644
--- a/examples/ppo/ppo_mujoco.py
+++ b/examples/ppo/ppo_mujoco.py
@@ -28,7 +28,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_mujoco import eval_model, make_env, make_ppo_models
 
-    # Define paper hyperparameters
     device = "cpu" if not torch.cuda.device_count() else "cuda"
     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
     total_network_updates = (
@@ -67,6 +66,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         value_network=critic,
         average_gae=False,
     )
+
     loss_module = ClipPPOLoss(
         actor=actor,
         critic=critic,
@@ -78,8 +78,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # Create optimizers
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
-    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5)
 
     # Create logger
     logger = None
@@ -120,9 +120,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         pbar.update(data.numel())
 
         # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "terminated"]]
+            episode_length = data["next", "step_count"][data["next", "done"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
@@ -187,7 +187,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 "train/lr": alpha * cfg_optim_lr,
                 "train/sampling_time": sampling_time,
                 "train/training_time": training_time,
-                "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+                "train/clip_epsilon": alpha * cfg_loss_clip_epsilon
+                if cfg_loss_anneal_clip_eps
+                else cfg_loss_clip_epsilon,
             }
         )
 
diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py
index 8fa2a53fd92..7be234b322d 100644
--- a/examples/ppo/utils_mujoco.py
+++ b/examples/ppo/utils_mujoco.py
@@ -28,10 +28,10 @@ def make_env(env_name="HalfCheetah-v4", device="cpu"):
     env = GymEnv(env_name, device=device)
     env = TransformedEnv(env)
+    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
     env.append_transform(RewardSum())
     env.append_transform(StepCounter())
-    env.append_transform(VecNorm(in_keys=["observation"]))
-    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
     env.append_transform(DoubleToFloat(in_keys=["observation"]))
     return env
 
 
@@ -72,7 +72,9 @@ def make_ppo_models_state(proof_environment):
     # Add state-independent normal scale
     policy_mlp = torch.nn.Sequential(
         policy_mlp,
-        AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
+        AddStateIndependentNormalScale(
+            proof_environment.action_spec.shape[-1], scale_lb=1e-8
+        ),
     )
 
     # Add probabilistic sampling of the actions
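Note on the annealing this patch turns on: `anneal_lr: True` makes the training script scale the learning rate (and, when `loss.anneal_clip_epsilon` is set, the PPO clip range) by a linear decay factor `alpha`, which is what the "train/lr" and "train/clip_epsilon" logging lines in the diff report. The standalone Python sketch below illustrates that schedule only; the dummy parameter, flag values, and update budget are illustrative assumptions, not the example script's exact code.

import torch

# Assumed values mirroring config_mujoco.yaml; the anneal flags are illustrative.
lr = 3e-4
clip_epsilon = 0.2
anneal_lr = True
anneal_clip_epsilon = False

# Hypothetical update budget; in the example it is derived from total_frames,
# frames_per_batch, ppo_epochs, and mini_batch_size.
total_network_updates = 1_000

# Dummy parameter so the optimizer has something to manage; eps=1e-5 as in the patch.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=lr, eps=1e-5)

for num_network_updates in range(total_network_updates):
    # Linear decay: alpha goes from 1.0 at the first update toward 0.0 at the last.
    alpha = 1.0 - num_network_updates / total_network_updates
    if anneal_lr:
        for group in optimizer.param_groups:
            group["lr"] = lr * alpha
    # The clip range is only annealed when its flag is set, matching the
    # conditional "train/clip_epsilon" logging in the diff.
    current_clip_eps = clip_epsilon * alpha if anneal_clip_epsilon else clip_epsilon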