[Feature] Discrete SAC compatibility with compile
ghstack-source-id: ddc131acedbbe451b28758e757a8c240ebd72b80
Pull Request resolved: #2569
vmoens committed Dec 14, 2024
1 parent fbfe104 commit 9e2d214
Showing 13 changed files with 127 additions and 104 deletions.
6 changes: 3 additions & 3 deletions sota-implementations/a2c/utils_atari.py
@@ -93,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device):
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
num_outputs = proof_environment.single_action_spec.space.n
if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
num_outputs = proof_environment.action_spec_unbatched.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.single_action_spec.shape
num_outputs = proof_environment.action_spec_unbatched.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
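The `single_action_spec` to `action_spec_unbatched` substitution above (the same rename appears in the MuJoCo, CQL, and Dreamer helpers below) reads the per-environment action spec of a possibly batched environment. A minimal sketch of the difference, assuming a recent TorchRL that exposes `action_spec_unbatched` and gymnasium installed; the environment and worker count are only illustrative:

from torchrl.envs import GymEnv, ParallelEnv

if __name__ == "__main__":
    env = ParallelEnv(2, lambda: GymEnv("CartPole-v1"))
    # The batched spec carries the environment batch dimension, e.g. torch.Size([2, 2]) here...
    print(env.action_spec.shape)
    # ...while the unbatched spec describes a single environment, e.g. torch.Size([2]),
    # which is what network output sizes should be derived from.
    print(env.action_spec_unbatched.shape)
    num_outputs = env.action_spec_unbatched.space.n  # number of discrete actions, as in the hunk above
    env.close()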
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -54,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.single_action_spec.shape[-1]
num_outputs = proof_environment.action_spec_unbatched.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -82,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
proof_environment.single_action_spec.shape[-1], device=device
proof_environment.action_spec_unbatched.shape[-1], device=device
),
)

1 change: 0 additions & 1 deletion sota-implementations/cql/cql_online.py
@@ -172,7 +172,6 @@ def update(sampled_tensordict):
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)
pbar.update(tensordict.numel())
# update weights of the inference policy
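This hunk removes the stray `torch.compiler.cudagraph_mark_step_begin()` from the collection phase; as the discrete SAC script below shows, the marker belongs immediately before each call into a cudagraph-capturing update, with the result cloned out of the static buffers. A toy sketch of that placement, assuming a CUDA device is available (the doubling function is a stand-in for the real loss/optimizer step):

import torch

@torch.compile(mode="reduce-overhead")  # "reduce-overhead" lets torch.compile capture CUDA graphs
def update(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

if torch.cuda.is_available():
    data = torch.ones(8, device="cuda")
    for _ in range(5):
        torch.compiler.cudagraph_mark_step_begin()  # a new step: prior outputs may be overwritten
        out = update(data).clone()  # clone so the result survives the next graph replay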
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
@@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


def make_cql_modules_state(model_cfg, proof_environment):
action_spec = proof_environment.single_action_spec
action_spec = proof_environment.action_spec_unbatched

actor_net_kwargs = {
"num_cells": model_cfg.hidden_sizes,
5 changes: 5 additions & 0 deletions sota-implementations/discrete_sac/config.yaml
@@ -44,6 +44,11 @@ network:
activation: relu
device: null

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
backend: wandb
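The three knobs added here feed the mode-selection logic introduced in discrete_sac.py (next file). A plain-Python sketch of that resolution, with a dict standing in for the Hydra config:

cfg_compile = {"compile": True, "compile_mode": None, "cudagraphs": False}  # mirrors the YAML defaults above, with compile flipped on

compile_mode = None
if cfg_compile["compile"]:
    compile_mode = cfg_compile["compile_mode"]
    if compile_mode in ("", None):
        # When cudagraphs is requested, capture is presumably left to CudaGraphModule, so plain
        # "default" is used; otherwise "reduce-overhead" lets torch.compile handle it.
        compile_mode = "default" if cfg_compile["cudagraphs"] else "reduce-overhead"
print(compile_mode)  # -> "reduce-overhead" for the values above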
170 changes: 90 additions & 80 deletions sota-implementations/discrete_sac/discrete_sac.py
@@ -10,19 +10,20 @@
The helper functions are coded in the utils.py associated with this script.
"""

from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger

from tensordict.nn import CudaGraphModule
from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
@@ -75,9 +76,6 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create discrete SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Create off-policy collector
collector = make_collector(cfg, train_env, model[0])

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
@@ -91,9 +89,57 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
cfg, loss_module
)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
del optimizer_actor, optimizer_critic, optimizer_alpha

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

# Compute loss
loss_out = loss_module(sampled_tensordict)

actor_loss, q_loss, alpha_loss = (
loss_out["loss_actor"],
loss_out["loss_qvalue"],
loss_out["loss_alpha"],
)

# Joint update: backpropagate the summed losses, then step the grouped optimizer
(q_loss + actor_loss + alpha_loss).backward()
optimizer.step()

# Update target params
target_net_updater.step()

return loss_out.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Create off-policy collector
collector = make_collector(
cfg,
train_env,
model[0],
compile=compile_mode is not None,
compile_mode=compile_mode,
cudagraphs=cfg.compile.cudagraphs,
)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
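The heart of the change is the fused `update` closure above: the three optimizers are merged with `group_optimizers` so a single `zero_grad`/`backward`/`step` sequence can be compiled, and the whole function is then optionally wrapped in `torch.compile` and tensordict's `CudaGraphModule`. A minimal, self-contained sketch of that wrapping order, assuming only torch and tensordict are installed; the toy quadratic loss, warmup value, and iteration count are illustrative:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = "cuda" if torch.cuda.is_available() else "cpu"
param = torch.nn.Parameter(torch.ones(8, device=device))
optimizer = torch.optim.Adam([param], lr=1e-3)

def update(td: TensorDict) -> TensorDict:
    optimizer.zero_grad(set_to_none=True)
    loss = ((param - td["target"]) ** 2).mean()  # stand-in for the SAC losses
    loss.backward()
    optimizer.step()
    return TensorDict({"loss": loss.detach()}, [])

update = torch.compile(update, mode="default")
if torch.cuda.is_available():
    # experimental: requires CUDA, static shapes, and a warmup phase before capture
    update = CudaGraphModule(update, warmup=5)

td = TensorDict({"target": torch.zeros_like(param.data)}, [])
for _ in range(10):
    torch.compiler.cudagraph_mark_step_begin()
    out = update(td).clone()  # clone mirrors the .clone() in the training loop below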

@@ -108,129 +154,93 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

sampling_start = time.time()
for i, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
collected_data = next(c_iter)

# Update weights of the inference policy
collector.update_policy_weights_()
current_frames = collected_data.numel()

pbar.update(tensordict.numel())
pbar.update(current_frames)

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_data = collected_data.reshape(-1)
with timeit("rb - extend"):
# Add to replay buffer
replay_buffer.extend(collected_data)
collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
alpha_losses,
) = ([], [], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
loss_out = loss_module(sampled_tensordict)

actor_loss, q_loss, alpha_loss = (
loss_out["loss_actor"],
loss_out["loss_qvalue"],
loss_out["loss_alpha"],
)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())
with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
sampled_tensordict = sampled_tensordict.to(device)
loss_out = update(sampled_tensordict).clone()

actor_losses.append(actor_loss.item())

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

alpha_losses.append(alpha_loss.item())

# Update target params
target_net_updater.step()
tds.append(loss_out)

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds).mean()

training_time = time.time() - training_start
# Logging
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
collected_data["next", "done"]
if collected_data["next", "done"].any()
else collected_data["next", "truncated"]
)
episode_rewards = tensordict["next", "episode_reward"][episode_end]
episode_rewards = collected_data["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
episode_length = collected_data["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/a_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

# Evaluation
prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 50 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
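The rewritten loop also swaps the manual `time.time()` bookkeeping for TorchRL's `timeit` sections and flushes the aggregated timings every 50 iterations. A sketch of that pattern; the section names and empty bodies are placeholders, while the `timeit` calls mirror the ones used above:

from torchrl._utils import timeit

for i in range(100):
    with timeit("collecting"):
        pass  # data collection would go here
    with timeit("update"):
        pass  # optimization step would go here
    if i % 50 == 0:
        metrics = timeit.todict(prefix="time")  # aggregated timings, ready to hand to the logger
        timeit.print()
        timeit.erase()  # reset the counters for the next logging window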
11 changes: 10 additions & 1 deletion sota-implementations/discrete_sac/utils.py
@@ -113,7 +113,14 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraphs=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
@@ -131,6 +138,8 @@ def make_collector(cfg, train_env, actor_model_explore):
reset_at_each_iter=cfg.collector.reset_at_each_iter,
device=device,
storing_device="cpu",
compile_policy=False if not compile else {"mode": compile_mode},
cudagraph_policy=cudagraphs,
)
collector.set_seed(cfg.env.seed)
return collector
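A usage sketch of the two collector arguments added here. It assumes the constructor hidden above this hunk is `SyncDataCollector`, as in the other SOTA scripts, and uses a throwaway linear policy on Pendulum-v1; both are assumptions for illustration, not something this diff guarantees:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
collector = SyncDataCollector(
    GymEnv("Pendulum-v1"),
    policy,
    frames_per_batch=64,
    total_frames=128,
    compile_policy={"mode": "default"},  # a dict is forwarded to torch.compile, as the conditional above suggests
    cudagraph_policy=False,              # keep CUDA-graph capture of the policy off in this sketch
)
for batch in collector:
    pass
collector.shutdown()

Passing `compile_policy=False` (the value used when compilation is disabled) leaves the policy untouched, matching the `False if not compile else {"mode": compile_mode}` conditional above.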
14 changes: 7 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py
@@ -475,12 +475,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
spec=Composite(
**{
"loc": Unbounded(
proof_environment.single_action_spec.shape,
device=proof_environment.single_action_spec.device,
proof_environment.action_spec_unbatched.shape,
device=proof_environment.action_spec_unbatched.device,
),
"scale": Unbounded(
proof_environment.single_action_spec.shape,
device=proof_environment.single_action_spec.device,
proof_environment.action_spec_unbatched.shape,
device=proof_environment.action_spec_unbatched.device,
),
}
),
@@ -491,7 +491,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=Composite(**{action_key: proof_environment.single_action_spec}),
spec=Composite(**{action_key: proof_environment.action_spec_unbatched}),
),
)
return actor_simulator
@@ -532,10 +532,10 @@ def _dreamer_make_actor_real(
spec=Composite(
**{
"loc": Unbounded(
proof_environment.single_action_spec.shape,
proof_environment.action_spec_unbatched.shape,
),
"scale": Unbounded(
proof_environment.single_action_spec.shape,
proof_environment.action_spec_unbatched.shape,
),
}
),
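The Dreamer helpers apply the same unbatched-spec substitution when declaring the actor's intermediate "loc"/"scale" outputs. A small sketch of that spec construction, assuming `Composite` and `Unbounded` are importable from `torchrl.data` (which is where these helpers appear to take them from), with a hard-coded shape standing in for `proof_environment.action_spec_unbatched.shape`:

import torch
from torchrl.data import Composite, Unbounded

action_shape = torch.Size([6])  # stand-in for proof_environment.action_spec_unbatched.shape
params_spec = Composite(
    loc=Unbounded(action_shape),
    scale=Unbounded(action_shape),
)
print(params_spec)  # a Composite spec with unbounded "loc" and "scale" entries of the action shape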