[Algorithm] Update TD3 Example (#1523)
BY571 authored Oct 3, 2023
1 parent 95b7206 commit df03cac
Showing 7 changed files with 328 additions and 249 deletions.
4 changes: 2 additions & 2 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -146,7 +146,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optimization.batch_size=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
@@ -247,7 +247,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optimization.batch_size=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
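
Note: the overrides above switch the test invocations from the old `optimization.*` config group to the renamed `optim.*` group introduced in this commit. As a minimal sketch (not part of the commit), the same dotted overrides can be resolved programmatically with Hydra's compose API, assuming Hydra 1.2+ and that the snippet is run from the repository root so that examples/td3/config.yaml is reachable:

# Sketch only: resolve CLI-style overrides against the TD3 config without running training.
from hydra import compose, initialize

with initialize(config_path="examples/td3", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=["optim.batch_size=10", "collector.total_frames=48"],
    )

# The batch size now lives under the renamed `optim` group;
# the old `optimization.batch_size` path no longer exists in config.yaml.
print(cfg.optim.batch_size)  # 10
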
33 changes: 18 additions & 15 deletions examples/td3/config.yaml
@@ -1,47 +1,50 @@
# Environment
# task and env
env:
name: HalfCheetah-v3
task: ""
exp_name: "HalfCheetah-TD3"
library: gym
frame_skip: 1
exp_name: ${env.name}_TD3
library: gymnasium
seed: 42
max_episode_steps: 1000

# Collection
# collector
collector:
total_frames: 1000000
init_random_frames: 10000
init_random_frames: 25_000
init_env_steps: 1000
frames_per_batch: 1000
max_frames_per_traj: 1000
async_collection: 1
reset_at_each_iter: False
collector_device: cpu
env_per_collector: 1
num_workers: 1

# Replay Buffer
# replay buffer
replay_buffer:
prb: 0 # use prioritized experience replay
size: 1000000
scratch_dir: ${env.exp_name}_${env.seed}

# Optimization
optimization:
# optim
optim:
utd_ratio: 1.0
gamma: 0.99
loss_function: l2
lr: 3e-4
weight_decay: 2e-4
lr: 3.0e-4
weight_decay: 0.0
adam_eps: 1e-4
batch_size: 256
target_update_polyak: 0.995
policy_update_delay: 2
policy_noise: 0.2
noise_clip: 0.5

# Network
# network
network:
hidden_sizes: [256, 256]
activation: relu
device: "cuda:0"

# Logging
# logging
logger:
backend: wandb
mode: online
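
For reference, a minimal sketch (not from the repository) of how the reworked config behaves: the new `exp_name: ${env.name}_TD3` entry is an OmegaConf interpolation, and the training hyperparameters now live under `optim` rather than `optimization`. The inline YAML below is a trimmed copy of the file above:

# Sketch only: shows the interpolation and the renamed group with plain OmegaConf.
from omegaconf import OmegaConf

yaml_snippet = """
env:
  name: HalfCheetah-v3
  exp_name: ${env.name}_TD3
  max_episode_steps: 1000
optim:
  lr: 3.0e-4
  batch_size: 256
  policy_update_delay: 2
"""
cfg = OmegaConf.create(yaml_snippet)

print(cfg.env.exp_name)      # HalfCheetah-v3_TD3 (resolved from ${env.name})
print(cfg.optim.batch_size)  # 256, accessed via the renamed `optim` group
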
134 changes: 76 additions & 58 deletions examples/td3/td3.py
@@ -11,8 +11,9 @@
The helper functions are coded in the utils.py associated with this script.
"""

import hydra
import time

import hydra
import numpy as np
import torch
import torch.cuda
@@ -22,6 +23,7 @@

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
log_metrics,
make_collector,
make_environment,
make_loss_module,
@@ -35,6 +37,7 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("TD3", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
@@ -45,140 +48,155 @@ def main(cfg: "DictConfig"): # noqa: F821
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create Environments
# Create environments
train_env, eval_env = make_environment(cfg)

# Create Agent
# Create agent
model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Make Off-Policy Collector
# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)

# Make Replay Buffer
# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optimization.batch_size,
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
device=device,
)

# Make Optimizers
# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

rewards = []
rewards_eval = []

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
r0 = None
q_loss = None

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optimization.utd_ratio
* cfg.optim.utd_ratio
)
delayed_updates = cfg.optimization.policy_update_delay
delayed_updates = cfg.optim.policy_update_delay
prb = cfg.replay_buffer.prb
env_per_collector = cfg.collector.env_per_collector
eval_rollout_steps = cfg.collector.max_frames_per_traj // cfg.env.frame_skip
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
frames_per_batch = cfg.collector.frames_per_batch
update_counter = 0

for i, tensordict in enumerate(collector):
sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
exploration_policy.step(tensordict.numel())
# update weights of the inference policy

# Update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for j in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

loss_td = loss_module(sampled_tensordict)
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
update_actor = j % delayed_updates == 0
q_loss.backward(retain_graph=update_actor)
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
if update_actor:
actor_loss, *_ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

actor_losses.append(actor_loss.item())

# update qnet_target params
# Update target params
target_net_updater.step()

# update priority
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append(
(i, tensordict["next", "reward"].sum().item() / env_per_collector)
training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
}
episode_rewards = tensordict["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if logger is not None:
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)
if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
if update_actor:
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
if logger is not None:
logger.log_scalar(
"evaluation_reward", rewards_eval[-1][1], step=collected_frames
)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
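
The reworked training loop above updates the critic on every gradient step but delays actor and target-network updates by `optim.policy_update_delay` steps, tracked with `update_counter`. A stripped-down sketch of that pattern follows (toy linear networks and a placeholder squared loss stand in for TorchRL's actual TD3 losses; this is not the example's code):

# Sketch only: delayed actor + Polyak target updates, mirroring the loop structure above.
import torch
from torch import nn, optim

actor = nn.Linear(4, 1)                      # toy stand-in for the TD3 policy
critic = nn.Linear(5, 1)                     # toy stand-in for the Q-network
target_critic = nn.Linear(5, 1)
target_critic.load_state_dict(critic.state_dict())

optimizer_actor = optim.Adam(actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(critic.parameters(), lr=3e-4)

policy_update_delay = 2                      # mirrors optim.policy_update_delay
polyak = 0.995                               # mirrors optim.target_update_polyak
update_counter = 0

for _ in range(8):                           # stand-in for the num_updates inner loop
    obs, act = torch.randn(32, 4), torch.randn(32, 1)   # stand-in replay-buffer batch

    # Critic update happens every iteration (placeholder loss, not the TD3 value loss).
    q_loss = critic(torch.cat([obs, act], dim=-1)).pow(2).mean()
    optimizer_critic.zero_grad()
    q_loss.backward()
    optimizer_critic.step()

    # Actor and target network are only updated every `policy_update_delay` iterations.
    update_counter += 1
    if update_counter % policy_update_delay == 0:
        actor_loss = -critic(torch.cat([obs, actor(obs)], dim=-1)).mean()
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        # Polyak averaging of the target critic (roughly what target_net_updater.step() does here).
        with torch.no_grad():
            for p, tp in zip(critic.parameters(), target_critic.parameters()):
                tp.mul_(polyak).add_((1.0 - polyak) * p)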