Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] CrossQ #2033

Merged
merged 49 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
0a23ae8
add crossQ examples
BY571 Mar 20, 2024
9bdee71
add loss
BY571 Mar 20, 2024
570a20e
Update naming experiment
BY571 Mar 21, 2024
5086249
update
BY571 Mar 21, 2024
c3a927f
update add tests
BY571 Mar 21, 2024
d1c9c34
detach
BY571 Mar 21, 2024
e879b7c
update tests
BY571 Mar 21, 2024
75255e7
update run_test.sh
BY571 Mar 21, 2024
a7b79c3
move crossq to sota-implementations
BY571 Mar 21, 2024
be84f3f
update loss
BY571 Mar 26, 2024
2170ad8
update cat prediction
BY571 Mar 26, 2024
75d4cee
Merge branch 'main' into crossQ
vmoens Jun 12, 2024
7711a4e
Merge branch 'main' into crossQ
BY571 Jun 26, 2024
f0ac167
add batchrenorm to crossq
BY571 Jun 26, 2024
37abb14
Merge branch 'crossQ' of github.com:BY571/rl into crossQ
BY571 Jun 26, 2024
bc7675a
small fixes
BY571 Jun 26, 2024
9543f2e
update docs and sota checks
BY571 Jun 26, 2024
53e35f7
hyperparam fix
BY571 Jun 26, 2024
172e1c0
test
BY571 Jun 27, 2024
fdb7e8b
update batch norm tests
BY571 Jun 27, 2024
5501d43
tests
BY571 Jul 3, 2024
c47ac84
cleanup
BY571 Jul 5, 2024
e718c3f
Merge branch 'main' into crossQ
BY571 Jul 5, 2024
f94165e
update
BY571 Jul 7, 2024
02c94ff
update lr param
BY571 Jul 8, 2024
93b6a7b
Merge branch 'crossQ' of https://github.com/BY571/rl into crossQ
BY571 Jul 8, 2024
4b914e6
Apply suggestions from code review
vmoens Jul 8, 2024
af8c64a
Merge remote-tracking branch 'origin/main' into crossQ
vmoens Jul 8, 2024
845c8a9
Merge branch 'crossQ' of https://github.com/BY571/rl into crossQ
vmoens Jul 8, 2024
7b4a69d
set qnet eval in actor loss
BY571 Jul 8, 2024
77de044
Merge branch 'crossQ' of https://github.com/BY571/rl into crossQ
BY571 Jul 8, 2024
35c7a98
take off comment
BY571 Jul 8, 2024
68a1a9f
amend
vmoens Jul 8, 2024
c04eb3b
Merge branch 'crossQ' of https://github.com/BY571/rl into crossQ
vmoens Jul 8, 2024
12672ee
Merge remote-tracking branch 'origin/main' into crossQ
vmoens Jul 8, 2024
7fbb27d
amend
vmoens Jul 8, 2024
ff80481
amend
vmoens Jul 8, 2024
caf702e
amend
vmoens Jul 8, 2024
70e2882
amend
vmoens Jul 8, 2024
ccd1b7f
amend
vmoens Jul 8, 2024
d3c8b0e
Merge remote-tracking branch 'origin/main' into crossQ
vmoens Jul 9, 2024
d3e0bb1
Apply suggestions from code review
vmoens Jul 9, 2024
349cb28
amend
vmoens Jul 9, 2024
75a43e7
amend
vmoens Jul 9, 2024
abada6c
fix device error
BY571 Jul 9, 2024
c878b81
Update objective delay actor
BY571 Jul 9, 2024
f222b11
Update tests not expecting target update
BY571 Jul 9, 2024
067b560
update example utils
BY571 Jul 9, 2024
c010e39
amend
vmoens Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di
replay_buffer.size=120 \
env.name=CartPole-v1 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device= \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device= \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.init_random_frames=10 \
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ REDQ

REDQLoss

CrossQ
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ

IQL
----

Expand Down
26 changes: 26 additions & 0 deletions sota-check/run_crossq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash

#SBATCH --job-name=crossq
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/crossq_%j.txt
#SBATCH --error=slurm_errors/crossq_%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="crossq"
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/crossq/crossq.py \
logger.backend=wandb \
logger.project_name="$project_name" \
logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
98 changes: 98 additions & 0 deletions sota-implementations/crossq/batchrenorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
vmoens marked this conversation as resolved.
Show resolved Hide resolved
import torch.nn as nn


class BatchRenorm(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this in the modules no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and add it to the doc.
Happy to write a couple of tests.
Is it a copy paste? If so, can we check the license?

"""
BatchRenorm Module (https://arxiv.org/abs/1702.03275).

BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
BatchRenorm utilizes running statistics to normalize batches after an initial warmup phase.
This approach reduces the impact of "outlier" batches that may occur during extended training periods,
making BatchRenorm more robust for long training runs.

During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.

Args:
num_features (int): Number of features in the input tensor.

Keyword Args:
momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01.
eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5.
max_r (float, optional): Maximum value for the scaling factor r. Default is 3.0.
max_d (float, optional): Maximum value for the bias factor d. Default is 5.0.
warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 10000.
"""

def __init__(
self,
num_features,
momentum=0.01,
eps=1e-5,
max_r=3.0,
max_d=5.0,
warmup_steps=10000,
):
super(BatchRenorm, self).__init__()
vmoens marked this conversation as resolved.
Show resolved Hide resolved
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.max_r = max_r
self.max_d = max_d
self.warmup_steps = warmup_steps

self.register_buffer(
"running_mean", torch.zeros(num_features, dtype=torch.float32)
)
self.register_buffer(
"running_var", torch.ones(num_features, dtype=torch.float32)
)
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64))
self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() >= 2
view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)
# _v = lambda v: v.view(view_dims)

def _v(v):
return v.view(view_dims)

running_std = (self.running_var + self.eps).sqrt_()

if self.training:
reduce_dims = [i for i in range(x.dim()) if i != 1]
b_mean = x.mean(reduce_dims)
b_var = x.var(reduce_dims, unbiased=False)
b_std = (b_var + self.eps).sqrt_()

r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r)
d = torch.clamp(
(b_mean.detach() - self.running_mean) / running_std,
-self.max_d,
self.max_d,
)

# Compute warmup factor (0 during warmup, 1 after warmup)
warmup_factor = torch.clamp(
self.num_batches_tracked / self.warmup_steps, 0.0, 1.0
)
r = 1.0 + (r - 1.0) * warmup_factor
d = d * warmup_factor

x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)

unbiased_var = b_var.detach() * x.shape[1] / (x.shape[1] - 1)
self.running_var += self.momentum * (unbiased_var - self.running_var)
self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
self.num_batches_tracked += 1
else:
x = (x - _v(self.running_mean)) / _v(running_std)

x = _v(self.weight) * x + _v(self.bias)
return x
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,21 @@ optim:
policy_update_delay: 3
gamma: 0.99
loss_function: l2
lr: 3.0e-4
lr: 1.0e-3
weight_decay: 0.0
batch_size: 256
alpha_init: 1.0
# Adam β1 = 0.5
adam_eps: 1.0e-8
beta1: 0.5
beta2: 0.999

# network
network:
batch_norm_momentum: 0.01
# warmup_steps: 100000 # 10^5
warmup_steps: 100000
critic_hidden_sizes: [2048, 2048]
actor_hidden_sizes: [256, 256]
critic_activation: tanh
critic_activation: relu
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
Expand Down
41 changes: 20 additions & 21 deletions examples/crossQ/crossQ.py → sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@
from utils import (
log_metrics,
make_collector,
make_crossQ_agent,
make_crossQ_optimizer,
make_environment,
make_loss_module,
make_replay_buffer,
make_sac_agent,
make_sac_optimizer,
)


@hydra.main(version_base="1.1", config_path=".", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)
BY571 marked this conversation as resolved.
Show resolved Hide resolved
if device is None:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create logger
exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
Expand All @@ -60,9 +62,9 @@ def main(cfg: "DictConfig"): # noqa: F821
train_env, eval_env = make_environment(cfg)

# Create agent
model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
model, exploration_policy = make_crossQ_agent(cfg, train_env, device)

# Create SAC loss
# Create CrossQ loss
loss_module = make_loss_module(cfg, model)

# Create off-policy collector
Expand All @@ -82,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_actor,
optimizer_critic,
optimizer_alpha,
) = make_sac_optimizer(cfg, loss_module)
) = make_crossQ_optimizer(cfg, loss_module)

# Main loop
start_time = time.time()
Expand Down Expand Up @@ -133,41 +135,40 @@ def main(cfg: "DictConfig"): # noqa: F821
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict)

q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)
q_loss = q_loss.mean()
# Update critic
optimizer_critic.zero_grad()
q_loss.mean().backward()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.mean().detach().item())
q_losses.append(q_loss.detach().item())

if update_actor:
actor_loss, metadata_actor = loss_module._actor_loss(
actor_loss, metadata_actor = loss_module.actor_loss(
sampled_tensordict
)
alpha_loss = loss_module._alpha_loss(
actor_loss = actor_loss.mean()
alpha_loss = loss_module.alpha_loss(
log_prob=metadata_actor["log_prob"]
vmoens marked this conversation as resolved.
Show resolved Hide resolved
)
).mean()

# Update actor
optimizer_actor.zero_grad()
actor_loss.mean().backward()
actor_loss.backward()
optimizer_actor.step()

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.mean().backward()
alpha_loss.backward()
optimizer_alpha.step()

actor_losses.append(actor_loss.mean().detach().item())
alpha_losses.append(alpha_loss.mean().detach().item())
actor_losses.append(actor_loss.detach().item())
alpha_losses.append(alpha_loss.detach().item())

# Update priority
if prb:
Expand All @@ -193,8 +194,6 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log["train/q_loss"] = np.mean(q_losses).item()
metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item()
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item()
# metrics_to_log["train/alpha"] = loss_td["alpha"].item()
# metrics_to_log["train/entropy"] = loss_td["entropy"].item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

Expand Down
Loading