Skip to content

Commit

Permalink
[Algorithm] CrossQ (pytorch#2033)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
3 people authored Jul 10, 2024
1 parent 824f6d1 commit a0a47a9
Show file tree
Hide file tree
Showing 14 changed files with 2,163 additions and 2 deletions.
12 changes: 12 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device= \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device= \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.init_random_frames=10 \
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
BatchRenorm

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ REDQ

REDQLoss

CrossQ
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ

IQL
----

Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_crossq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=crossq
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/crossq_%j.txt
#SBATCH --error=slurm_errors/crossq_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="crossq"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/crossq/crossq.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
58 changes: 58 additions & 0 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# environment and task
env:
name: HalfCheetah-v4
task: ""
library: gym
max_episode_steps: 1000
seed: 42

# collector
collector:
total_frames: 1_000_000
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
env_per_collector: 1
reset_at_each_iter: False

# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: null

# optim
optim:
utd_ratio: 1.0
policy_update_delay: 3
gamma: 0.99
loss_function: l2
lr: 1.0e-3
weight_decay: 0.0
batch_size: 256
alpha_init: 1.0
adam_eps: 1.0e-8
beta1: 0.5
beta2: 0.999

# network
network:
batch_norm_momentum: 0.01
warmup_steps: 100000
critic_hidden_sizes: [2048, 2048]
actor_hidden_sizes: [256, 256]
critic_activation: relu
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"

# logging
logger:
backend: wandb
project_name: torchrl_example_crossQ
group_name: null
exp_name: ${env.name}_CrossQ
mode: online
eval_iter: 25000
229 changes: 229 additions & 0 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""CrossQ Example.
This is a simple self-contained example of a CrossQ training script.
It supports state environments like MuJoCo.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
log_metrics,
make_collector,
make_crossQ_agent,
make_crossQ_optimizer,
make_environment,
make_loss_module,
make_replay_buffer,
)


@hydra.main(version_base="1.1", config_path=".", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
device = cfg.network.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
device = torch.device(device)

# Create logger
exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="crossq_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create environments
train_env, eval_env = make_environment(cfg)

# Create agent
model, exploration_policy = make_crossQ_agent(cfg, train_env, device)

# Create CrossQ loss
loss_module = make_loss_module(cfg, model)

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
)

# Create optimizers
(
optimizer_actor,
optimizer_critic,
optimizer_alpha,
) = make_crossQ_optimizer(cfg, loss_module)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
update_counter = 0
delayed_updates = cfg.optim.policy_update_delay
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
alpha_losses,
q_losses,
) = ([], [], [])
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
q_loss = q_loss.mean()
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.detach().item())

if update_actor:
actor_loss, metadata_actor = loss_module.actor_loss(
sampled_tensordict
)
actor_loss = actor_loss.mean()
alpha_loss = loss_module.alpha_loss(
log_prob=metadata_actor["log_prob"]
).mean()

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

actor_losses.append(actor_loss.detach().item())
alpha_losses.append(alpha_loss.detach().item())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
episode_rewards = tensordict["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Loading

0 comments on commit a0a47a9

Please sign in to comment.