Skip to content

Commit

Permalink
[Algorithm] TD3+BC (pytorch#2249)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
3 people authored Jul 9, 2024
1 parent 2555437 commit 3566388
Show file tree
Hide file tree
Showing 11 changed files with 1,793 additions and 10 deletions.
4 changes: 3 additions & 1 deletion .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \
optim.gradient_steps=55 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr
- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ TD3

TD3Loss

TD3+BC
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TD3BCLoss

PPO
---

Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_td3bc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=td3bc_offline
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/td3bc_offline_%j.txt
#SBATCH --error=slurm_errors/td3bc_offline_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="td3bc_offline"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >>> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >>> report.log
fi
1 change: 1 addition & 0 deletions sota-check/submitit-release-check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ scripts=(
run_ppo_mujoco.sh
run_sac.sh
run_td3.sh
run_td3bc.sh
run_dt.sh
run_dt_online.sh
)
Expand Down
45 changes: 45 additions & 0 deletions sota-implementations/td3_bc/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# task and env
env:
name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency
task: ""
library: gymnasium
seed: 42
max_episode_steps: 1000

# replay buffer
replay_buffer:
dataset: halfcheetah-medium-v2
batch_size: 256

# optim
optim:
gradient_steps: 100000
gamma: 0.99
loss_function: l2
lr: 3.0e-4
weight_decay: 0.0
adam_eps: 1e-4
batch_size: 256
target_update_polyak: 0.995
policy_update_delay: 2
policy_noise: 0.2
noise_clip: 0.5
alpha: 2.5

# network
network:
hidden_sizes: [256, 256]
activation: relu
device: null

# logging
logger:
backend: wandb
project_name: td3+bc_${replay_buffer.dataset}
group_name: null
exp_name: TD3+BC_${replay_buffer.dataset}
mode: online
eval_iter: 5000
eval_steps: 1000
eval_envs: 1
video: False
146 changes: 146 additions & 0 deletions sota-implementations/td3_bc/td3_bc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""TD3+BC Example.
This is a self-contained example of an offline RL TD3+BC training script.
The helper functions are coded in the utils.py associated with this script.
"""
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
dump_video,
log_metrics,
make_environment,
make_loss_module,
make_offline_replay_buffer,
make_optimizer,
make_td3_agent,
)


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.library).set()

# Create logger
exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="td3bc_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = cfg.network.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)

# Creante env
eval_env = make_environment(
cfg,
logger=logger,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create agent
model, _ = make_td3_agent(cfg, eval_env, device)

# Create loss
loss_module, target_net_updater = make_loss_module(cfg.optim, model)

# Create optimizer
optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
update_counter = 0
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
start_time = time.time()
for i in pbar:
pbar.update(1)
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_loss.item()

to_log = {"q_loss": q_loss.item()}

# Update actor
if update_actor:
actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

# Update target params
target_net_updater.step()

to_log["actor_loss"] = actor_loss.item()
to_log.update(actorloss_metadata)

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
main()
Loading

0 comments on commit 3566388

Please sign in to comment.