[BugFix] Change ppo mujoco example to match paper results (pytorch#1714)
albertbou92 authored Nov 27, 2023
1 parent 38d9cb7 commit aedcf29
Showing 4 changed files with 16 additions and 12 deletions.
examples/a2c/a2c_mujoco.py (2 additions, 2 deletions)

@@ -101,9 +101,9 @@ def main(cfg: "DictConfig"): # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "terminated"]]
+            episode_length = data["next", "step_count"][data["next", "done"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
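Note (not part of the diff): in TorchRL, as in Gymnasium, "done" is the union of "terminated" (a true terminal state) and "truncated" (a time-limit cutoff). MuJoCo locomotion episodes usually end by truncation, so masking on "terminated" alone misses most finished episodes; masking on "done" is what lets the training reward actually get logged. A minimal illustration:

    import torch

    terminated = torch.tensor([[False], [False], [False]])
    truncated = torch.tensor([[False], [False], [True]])  # episode cut by the time limit
    done = terminated | truncated

    episode_reward = torch.tensor([[10.0], [25.0], [310.0]])
    print(episode_reward[terminated])  # tensor([]) -> nothing would be logged
    print(episode_reward[done])        # tensor([310.]) -> the finished episode is logged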
examples/ppo/config_mujoco.yaml (1 addition, 1 deletion)

@@ -18,7 +18,7 @@ logger:
 optim:
   lr: 3e-4
   weight_decay: 0.0
-  anneal_lr: False
+  anneal_lr: True

 # loss
 loss:
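Note (illustrative, names assumed): with anneal_lr: True the training script scales the learning rate by a factor alpha that decays linearly from 1 to 0 over the scheduled number of network updates, which is also why the script logs "train/lr" as alpha * cfg_optim_lr. A minimal sketch of that schedule:

    def annealed_lr(base_lr: float, num_network_updates: int, total_network_updates: int) -> float:
        # alpha goes from 1.0 at the start of training to 0.0 at the end
        alpha = 1.0 - num_network_updates / total_network_updates
        return base_lr * alpha

    print(annealed_lr(3e-4, 0, 1000))    # 0.0003 at the first update
    print(annealed_lr(3e-4, 500, 1000))  # 0.00015 halfway through training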
examples/ppo/ppo_mujoco.py (8 additions, 6 deletions)

@@ -28,7 +28,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_mujoco import eval_model, make_env, make_ppo_models

-    # Define paper hyperparameters
     device = "cpu" if not torch.cuda.device_count() else "cuda"
     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
     total_network_updates = (
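For orientation, a small worked example of the mini-batch arithmetic in the hunk above; the config numbers are assumed for illustration and are not taken from this diff:

    # Assumed values; the real ones live in examples/ppo/config_mujoco.yaml.
    frames_per_batch = 2048
    mini_batch_size = 64

    num_mini_batches = frames_per_batch // mini_batch_size
    print(num_mini_batches)  # 32 gradient mini-batches per collected batch
    # total_network_updates (cut off in the hunk) would additionally factor in the
    # number of collected batches and PPO epochs; that part is not shown here.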
@@ -67,6 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821
         value_network=critic,
         average_gae=False,
     )
+
     loss_module = ClipPPOLoss(
         actor=actor,
         critic=critic,
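For readers skimming the context lines above: the module whose construction closes at the ")" is the advantage estimator (GAE), which feeds ClipPPOLoss. A self-contained sketch of generalized advantage estimation, illustrative only and not TorchRL's implementation (gamma and lmbda here are common defaults, not necessarily the ones in config_mujoco.yaml):

    import torch

    def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
        # A_t = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
        # delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
        advantages = torch.zeros_like(rewards)
        running = torch.zeros(())
        for t in reversed(range(rewards.shape[0])):
            not_done = 1.0 - dones[t].float()
            delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
            running = delta + gamma * lmbda * not_done * running
            advantages[t] = running
        return advantages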
@@ -78,8 +78,8 @@ def main(cfg: "DictConfig"): # noqa: F821
     )

     # Create optimizers
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
-    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5)

     # Create logger
     logger = None
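Note (illustrative arithmetic, not the example's code): PyTorch's Adam defaults to eps=1e-8, while the PPO reference implementation uses 1e-5. The epsilon sits in the update denominator, step = lr * m_hat / (sqrt(v_hat) + eps), so the larger value damps updates once the gradient statistics get very small:

    import math

    lr, m_hat, v_hat = 3e-4, 1e-6, 1e-12  # tiny, nearly-converged gradient statistics
    step_default = lr * m_hat / (math.sqrt(v_hat) + 1e-8)  # ~2.97e-4
    step_paper = lr * m_hat / (math.sqrt(v_hat) + 1e-5)    # ~2.7e-5, roughly an order of magnitude smaller
    print(step_default, step_paper)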
@@ -120,9 +120,9 @@ def main(cfg: "DictConfig"): # noqa: F821
         pbar.update(data.numel())

         # Get training rewards and episode lengths
-        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
         if len(episode_rewards) > 0:
-            episode_length = data["next", "step_count"][data["next", "terminated"]]
+            episode_length = data["next", "step_count"][data["next", "done"]]
             log_info.update(
                 {
                     "train/reward": episode_rewards.mean().item(),
@@ -187,7 +187,9 @@ def main(cfg: "DictConfig"): # noqa: F821
                     "train/lr": alpha * cfg_optim_lr,
                     "train/sampling_time": sampling_time,
                     "train/training_time": training_time,
-                    "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+                    "train/clip_epsilon": alpha * cfg_loss_clip_epsilon
+                    if cfg_loss_anneal_clip_eps
+                    else cfg_loss_clip_epsilon,
                 }
             )

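As context for the logging change above, a rough sketch (assumed names, not the example's code) of how an annealed clipping coefficient enters the clipped PPO surrogate; when annealing is enabled, epsilon shrinks with the same alpha used for the learning rate, and the logged "train/clip_epsilon" simply mirrors that effective value:

    import torch

    def clipped_surrogate(ratio, advantage, base_eps=0.2, alpha=1.0, anneal=True):
        # base_eps=0.2 is a common PPO default, assumed here for illustration
        eps = alpha * base_eps if anneal else base_eps
        unclipped = ratio * advantage
        clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * advantage
        return torch.min(unclipped, clipped).mean()  # maximized (or negated as a loss)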
examples/ppo/utils_mujoco.py (5 additions, 3 deletions)

@@ -28,10 +28,10 @@
 def make_env(env_name="HalfCheetah-v4", device="cpu"):
     env = GymEnv(env_name, device=device)
     env = TransformedEnv(env)
+    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
     env.append_transform(RewardSum())
     env.append_transform(StepCounter())
-    env.append_transform(VecNorm(in_keys=["observation"]))
-    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
     env.append_transform(DoubleToFloat(in_keys=["observation"]))
     return env
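Note: VecNorm keeps running statistics of the observation and now gets an explicit decay (how heavily old statistics are weighted) and eps (numerical stabilizer), and the normalized observation is clipped to [-10, 10] right after. A rough sketch of that combined effect, not VecNorm's actual internals:

    import torch

    def update_stats(obs, mean, var, decay=0.99999):
        # heavily weight the old statistics; new observations move them only slowly
        mean = decay * mean + (1.0 - decay) * obs
        var = decay * var + (1.0 - decay) * (obs - mean) ** 2
        return mean, var

    def normalize_and_clip(obs, mean, var, eps=1e-2, low=-10.0, high=10.0):
        return ((obs - mean) / (var.sqrt() + eps)).clamp(low, high)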

@@ -72,7 +72,9 @@ def make_ppo_models_state(proof_environment):
     # Add state-independent normal scale
     policy_mlp = torch.nn.Sequential(
         policy_mlp,
-        AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
+        AddStateIndependentNormalScale(
+            proof_environment.action_spec.shape[-1], scale_lb=1e-8
+        ),
     )

     # Add probabilistic sampling of the actions
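Note: the scale_lb argument puts a floor under the learned, state-independent standard deviation of the Gaussian policy, so exploration noise can never collapse exactly to zero. A rough sketch of that idea with assumed names, not TorchRL's AddStateIndependentNormalScale:

    import torch
    from torch import nn

    class StateIndependentScale(nn.Module):
        def __init__(self, action_dim: int, scale_lb: float = 1e-8):
            super().__init__()
            self.log_scale = nn.Parameter(torch.zeros(action_dim))  # learned, same for every state
            self.scale_lb = scale_lb

        def forward(self, loc: torch.Tensor):
            scale = self.log_scale.exp().clamp_min(self.scale_lb)  # lower-bounded std
            return loc, scale.expand_as(loc)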
