[Algorithm] Discrete IQL (pytorch#1793)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
BY571 and vmoens authored Jan 16, 2024
1 parent c11713a commit 69f3c4e
Showing 11 changed files with 1,569 additions and 48 deletions.
9 changes: 9 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -200,6 +200,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
  collector.device=cuda:0 \
  logger.mode=offline \
  logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_iql.py \
  collector.total_frames=48 \
  optim.batch_size=10 \
  collector.frames_per_batch=16 \
  env.train_num_envs=2 \
  optim.device=cuda:0 \
  collector.device=cuda:0 \
  logger.mode=offline \
  logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
  collector.total_frames=48 \
  optim.batch_size=10 \
1 change: 1 addition & 0 deletions docs/source/reference/objectives.rst
@@ -127,6 +127,7 @@ IQL
    :template: rl_template_noinherit.rst

    IQLLoss
    DiscreteIQLLoss

CQL
----
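Beyond the docs entry, here is a rough, hand-written sketch of how the new DiscreteIQLLoss might be wired up outside of the example script. It assumes the constructor mirrors IQLLoss (actor, Q-value and value networks plus temperature/expectile) and that a one-hot categorical actor with logits/action_value/state_value keys is acceptable; the network sizes, keys and the action_space keyword are assumptions rather than part of this commit, so check examples/iql/utils.py and the API docs for the exact signature.

# Illustrative only: names, keys and keyword arguments are assumptions,
# not taken from this commit. See examples/iql/utils.py for the real setup.
from tensordict.nn import TensorDictModule
from torchrl.modules import MLP, OneHotCategorical, ProbabilisticActor
from torchrl.objectives import DiscreteIQLLoss, HardUpdate

obs_dim, n_actions = 4, 2  # e.g. CartPole-v1

# Policy: logits over the discrete actions -> one-hot categorical distribution
actor = ProbabilisticActor(
    module=TensorDictModule(
        MLP(in_features=obs_dim, out_features=n_actions, num_cells=[256, 256]),
        in_keys=["observation"],
        out_keys=["logits"],
    ),
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
)

# Q-network: one value per discrete action
qvalue = TensorDictModule(
    MLP(in_features=obs_dim, out_features=n_actions, num_cells=[256, 256]),
    in_keys=["observation"],
    out_keys=["action_value"],
)

# State-value network used for the expectile regression
value = TensorDictModule(
    MLP(in_features=obs_dim, out_features=1, num_cells=[256, 256]),
    in_keys=["observation"],
    out_keys=["state_value"],
)

loss_module = DiscreteIQLLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    value_network=value,
    temperature=100.0,  # values taken from discrete_iql.yaml below
    expectile=0.8,
    action_space="one-hot",  # assumption: may also be inferred from the actor spec
)
loss_module.make_value_estimator(gamma=0.99)
target_net_updater = HardUpdate(loss_module, value_network_update_interval=10)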
195 changes: 195 additions & 0 deletions examples/iql/discrete_iql.py
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IQL Example.
This is a self-contained example of an online discrete IQL training script.
It works across Gym and MuJoCo over a variety of tasks.
The helper functions are coded in the utils.py associated with this script.
"""
import logging
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
    log_metrics,
    make_collector,
    make_discrete_iql_model,
    make_discrete_loss,
    make_environment,
    make_iql_optimizer,
    make_replay_buffer,
)


@hydra.main(config_path=".", config_name="discrete_iql")
def main(cfg: "DictConfig"): # noqa: F821
    # Create logger
    exp_name = generate_exp_name("Discrete-IQL-online", cfg.env.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="iql_logging",
            experiment_name=exp_name,
            wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    device = torch.device(cfg.optim.device)

    # Create environments
    train_env, eval_env = make_environment(
        cfg,
        cfg.env.train_num_envs,
        cfg.env.eval_num_envs,
    )

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        device="cpu",
    )

    # Create model
    model = make_discrete_iql_model(cfg, train_env, eval_env, device)

    # Create collector
    collector = make_collector(cfg, train_env, actor_model_explore=model[0])

    # Create loss
    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)

    # Create optimizer
    optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
        cfg.optim, loss_module
    )

    # Main loop
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
    prb = cfg.replay_buffer.prb
    eval_iter = cfg.logger.eval_iter
    frames_per_batch = cfg.collector.frames_per_batch
    eval_rollout_steps = cfg.collector.max_frames_per_traj
    sampling_start = start_time = time.time()
    for tensordict in collector:
        sampling_time = time.time() - sampling_start
        pbar.update(tensordict.numel())
        # update weights of the inference policy
        collector.update_policy_weights_()

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

        # optimization steps
        training_start = time.time()
        if collected_frames >= init_random_frames:
            for _ in range(num_updates):
                # sample from replay buffer
                sampled_tensordict = replay_buffer.sample().clone()
                if sampled_tensordict.device != device:
                    sampled_tensordict = sampled_tensordict.to(
                        device, non_blocking=True
                    )
                # compute losses
                actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

                value_loss, _ = loss_module.value_loss(sampled_tensordict)
                optimizer_value.zero_grad()
                value_loss.backward()
                optimizer_value.step()

                q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()

                # update qnet_target params
                target_net_updater.step()

                # update priority
                if prb:
                    sampled_tensordict.set(
                        loss_module.tensor_keys.priority,
                        metadata.pop("td_error").detach().max(0).values,
                    )
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start
        episode_rewards = tensordict["next", "episode_reward"][
            tensordict["next", "done"]
        ]

        # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][
                tensordict["next", "done"]
            ]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )
        if collected_frames >= init_random_frames:
            metrics_to_log["train/q_loss"] = q_loss.detach()
            metrics_to_log["train/actor_loss"] = actor_loss.detach()
            metrics_to_log["train/value_loss"] = value_loss.detach()
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        # Evaluation
        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
                metrics_to_log["eval/time"] = eval_time
        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
        sampling_start = time.time()

    collector.shutdown()
    end_time = time.time()
    execution_time = end_time - start_time
    logging.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
58 changes: 58 additions & 0 deletions examples/iql/discrete_iql.yaml
@@ -0,0 +1,58 @@
# task and env
env:
  name: CartPole-v1
  task: ""
  exp_name: iql_${env.name}
  n_samples_stats: 1000
  seed: 0
  train_num_envs: 1
  eval_num_envs: 1
  backend: gym


# collector
collector:
  frames_per_batch: 200
  total_frames: 20000
  init_random_frames: 1000
  env_per_collector: 1
  device: cpu
  max_frames_per_traj: 200

# logger
logger:
  backend: wandb
  log_interval: 5000  # record interval in frames
  eval_steps: 200
  mode: online
  eval_iter: 1000

# replay buffer
replay_buffer:
  prb: 0
  buffer_prefetch: 64
  size: 1_000_000

# optimization
optim:
  utd_ratio: 1
  device: cuda:0
  lr: 3e-4
  weight_decay: 0.0
  batch_size: 256

# network
model:
  hidden_sizes: [256, 256]
  activation: relu


# loss
loss:
  loss_function: l2
  gamma: 0.99
  hard_update_interval: 10

  # IQL specific hyperparameter
  temperature: 100
  expectile: 0.8
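The last two entries, temperature and expectile, are the IQL-specific knobs: the expectile controls the asymmetric regression that pushes V(s) toward the upper end of the Q-values in the batch (approximating an in-sample maximum), while the temperature scales the advantage before it is exponentiated to weight the policy update. The snippet below is a standalone illustration of how these two values typically enter the losses, following the IQL paper's formulation; the tensors are random placeholders, not code from this commit.

# Standalone illustration of the two IQL hyperparameters above.
# Random placeholder tensors; this is not code from the commit.
import torch

expectile = 0.8      # tau > 0.5 biases V(s) toward the upper expectile of Q(s, a)
temperature = 100.0  # beta: sharpness of the advantage weighting in the actor loss

q_sa = torch.randn(256)      # Q(s, a) for dataset/replay actions
v_s = torch.randn(256)       # V(s) from the value network
log_prob = torch.randn(256)  # log pi(a | s) for the same actions

# Value loss: asymmetric (expectile) L2 regression of V(s) toward Q(s, a)
diff = q_sa.detach() - v_s
weight = torch.abs(expectile - (diff < 0).float())  # tau where diff >= 0, (1 - tau) otherwise
value_loss = (weight * diff.pow(2)).mean()

# Actor loss: advantage-weighted log-likelihood of the dataset actions
adv = (q_sa - v_s).detach()
exp_adv = (temperature * adv).exp().clamp(max=100.0)  # clamping is a common stability guard
actor_loss = -(exp_adv * log_prob).mean()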
40 changes: 26 additions & 14 deletions examples/iql/iql_offline.py
@@ -34,7 +34,6 @@
@set_gym_backend("gym")
@hydra.main(config_path=".", config_name="offline_config")
def main(cfg: "DictConfig"): # noqa: F821

    # Create logger
    exp_name = generate_exp_name("IQL-offline", cfg.env.exp_name)
    logger = None
@@ -64,7 +63,9 @@ def main(cfg: "DictConfig"): # noqa: F821
    loss_module, target_net_updater = make_loss(cfg.loss, model)

    # Create optimizer
    optimizer = make_iql_optimizer(cfg.optim, loss_module)
    optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
        cfg.optim, loss_module
    )

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

@@ -78,18 +79,29 @@ def main(cfg: "DictConfig"): # noqa: F821
        pbar.update(i)
        # sample data
        data = replay_buffer.sample()
        # compute loss
        loss_vals = loss_module(data.clone().to(device))

        actor_loss = loss_vals["loss_actor"]
        q_loss = loss_vals["loss_qvalue"]
        value_loss = loss_vals["loss_value"]
        loss_val = actor_loss + q_loss + value_loss

        # update model
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        if data.device != device:
            data = data.to(device, non_blocking=True)

        # compute losses
        loss_info = loss_module(data)
        actor_loss = loss_info["loss_actor"]
        value_loss = loss_info["loss_value"]
        q_loss = loss_info["loss_qvalue"]

        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

        optimizer_critic.zero_grad()
        q_loss.backward()
        optimizer_critic.step()

        # update qnet_target params
        target_net_updater.step()

        # log metrics
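The hunks above switch iql_offline.py from a single optimizer over the summed loss to one optimizer per network, but the updated make_iql_optimizer itself lives in examples/iql/utils.py, which is not loaded in this view. For orientation, here is a sketch of what such a helper could look like, assuming the loss module exposes actor_network_params, qvalue_network_params and value_network_params tensordicts as other torchrl losses do; treat the exact parameter grouping as an assumption. Splitting the optimizers keeps each backward pass scoped to its own network, matching the per-loss zero_grad()/backward()/step() sequence in the rewritten loop.

# Hypothetical sketch; the real helper is in examples/iql/utils.py (not shown here).
# Assumes the loss module exposes per-network parameter tensordicts, as other
# torchrl losses do.
from torch import optim


def make_iql_optimizer(optim_cfg, loss_module):
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
    value_params = list(loss_module.value_network_params.flatten_keys().values())

    optimizer_actor = optim.Adam(
        actor_params, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay
    )
    optimizer_critic = optim.Adam(
        critic_params, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay
    )
    optimizer_value = optim.Adam(
        value_params, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay
    )
    return optimizer_actor, optimizer_critic, optimizer_value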
