[Algorithm] Simpler IQL example (pytorch#998)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
BY571 and vmoens authored Dec 14, 2023
1 parent 0906206 commit bc4a72f
Showing 19 changed files with 1,115 additions and 598 deletions.
43 changes: 33 additions & 10 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -45,9 +45,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=

# ==================================================================================== #
# ================================ Gymnasium ========================================= #
@@ -115,7 +121,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_c
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
optim.optim_steps_per_batch=1 \
replay_buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
@@ -174,11 +179,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
collector.total_frames=48 \
buffer.batch_size=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
collector.device=cuda:0 \
optim.device=cuda:0 \
logger.mode=offline \
logger.backend=

@@ -248,12 +262,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
env.train_num_envs=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
buffer.batch_size=10 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
collector.total_frames=48 \
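Note on the overrides above: each CI invocation passes Hydra-style dotted overrides (for example optim.gradient_steps=55 or an empty logger.backend=) to shrink the example scripts into quick smoke tests. The following is a minimal sketch of how such overrides merge into a config, using OmegaConf (the config library underneath Hydra) directly rather than a full @hydra.main entry point; the base values below are illustrative only, not the examples' real defaults.

# Standalone sketch: how dotted key=value overrides like the ones in this
# script modify a config. The base values here are made up for illustration.
from omegaconf import OmegaConf

base = OmegaConf.create(
    {
        "optim": {"gradient_steps": 1_000_000, "device": "cpu"},
        "logger": {"backend": "wandb", "mode": "online"},
    }
)
overrides = OmegaConf.from_dotlist(
    ["optim.gradient_steps=55", "optim.device=cuda:0", "logger.backend="]
)
cfg = OmegaConf.merge(base, overrides)

assert cfg.optim.gradient_steps == 55
# An empty value such as `logger.backend=` ends up falsy, which the example
# scripts use to skip logger creation (`if cfg.logger.backend: ...`).
print(OmegaConf.to_yaml(cfg))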
112 changes: 78 additions & 34 deletions examples/cql/cql_offline.py
@@ -10,6 +10,8 @@
"""

import time

import hydra
import numpy as np
import torch
@@ -18,16 +20,18 @@
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
log_metrics,
make_continuous_cql_optimizer,
make_continuous_loss,
make_cql_model,
make_cql_optimizer,
make_environment,
make_loss,
make_offline_replay_buffer,
)


@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
# Create logger
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
@@ -37,49 +41,96 @@ def main(cfg: "DictConfig"): # noqa: F821
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)

# Make Env
# Create env
train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)

# Make Buffer
# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Make Model
# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)

# Make Loss
loss_module, target_net_updater = make_loss(cfg.loss, model)
# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)

# Make Optimizer
optimizer = make_cql_optimizer(cfg.optim, loss_module)
# Create Optimizer
(
policy_optim,
critic_optim,
alpha_optim,
alpha_prime_optim,
) = make_continuous_cql_optimizer(cfg, loss_module)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

r0 = None
l0 = None

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(i)
# sample data
data = replay_buffer.sample()
# loss
loss_vals = loss_module(data)
# backprop
actor_loss = loss_vals["loss_actor"]
# compute loss
loss_vals = loss_module(data.clone().to(device))

# The official CQL implementation uses a behavior-cloning loss for the first few update steps, as it helps on some tasks
if i >= policy_eval_start:
actor_loss = loss_vals["loss_actor"]
else:
actor_loss = loss_vals["loss_actor_bc"]
q_loss = loss_vals["loss_qvalue"]
value_loss = loss_vals["loss_value"]
loss_val = actor_loss + q_loss + value_loss
cql_loss = loss_vals["loss_cql"]

q_loss = q_loss + cql_loss

alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

optimizer.zero_grad()
loss_val.backward()
optimizer.step()
policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_optim.step()

critic_optim.zero_grad()
# TODO: since the losses can be computed independently, retain_graph may not be needed here
q_loss.backward(retain_graph=False)
critic_optim.step()

loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss

# log metrics
to_log = {
"loss": loss.item(),
"loss_actor_bc": loss_vals["loss_actor_bc"].item(),
"loss_actor": loss_vals["loss_actor"].item(),
"loss_qvalue": q_loss.item(),
"loss_cql": cql_loss.item(),
"loss_alpha": alpha_loss.item(),
"loss_alpha_prime": alpha_prime_loss.item(),
}

# update qnet_target params
target_net_updater.step()

# evaluation
@@ -88,20 +139,13 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

if r0 is None:
r0 = eval_td["next", "reward"].sum(1).mean().item()
if l0 is None:
l0 = loss_val.item()

for key, value in loss_vals.items():
logger.log_scalar(key, value.item(), i)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
logger.log_scalar("evaluation_reward", eval_reward, i)
log_metrics(logger, to_log, i)

pbar.set_description(
f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation_reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
)
pbar.close()
print(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
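For orientation, the retooled training step above replaces the old script's single summed loss_val with one optimizer and one backward pass per loss term (alpha, policy, optional alpha_prime, critic). Below is a minimal, self-contained sketch of that pattern; the toy networks, loss expressions, and learning rate are placeholders rather than the repository's loss_module or models.

# Toy illustration of the per-term update order used in cql_offline.py:
# temperature (alpha) first, then the policy, then the critic, each with its
# own optimizer and its own backward pass. All modules and losses are dummies.
import torch
from torch import nn

actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
log_alpha = torch.zeros(1, requires_grad=True)

policy_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

obs = torch.randn(8, 4)
actor_out = actor(obs)
critic_out = critic(obs)

# Placeholder loss terms standing in for the dict returned by the loss module.
alpha_loss = (-log_alpha).mean() * actor_out.detach().abs().mean()
actor_loss = actor_out.pow(2).mean() - critic_out.detach().mean()
q_loss = critic_out.pow(2).mean()

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()

critic_optim.zero_grad()
q_loss.backward()  # each term here has its own graph, so no retain_graph needed
critic_optim.step()

Keeping each backward pass on its own graph is what lets the critic update in the diff pass retain_graph=False; only the alpha_prime term retains the graph, presumably because it shares intermediates with the critic loss.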