[Algorithm] Discrete IQL #1793

Merged (10 commits) on Jan 16, 2024
9 changes: 9 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -200,6 +200,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/discrete_iql.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
1 change: 1 addition & 0 deletions docs/source/reference/objectives.rst
@@ -127,6 +127,7 @@ IQL
:template: rl_template_noinherit.rst

IQLLoss
DiscreteIQLLoss

CQL
----
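For readers of the docs entry above: the training script in this PR builds its loss through a make_discrete_loss helper in examples/iql/utils.py, which is not part of the visible diff. The snippet below is a minimal sketch of how the exported DiscreteIQLLoss and a hard target-network update could be wired together with the values from discrete_iql.yaml; the keyword names (loss_function, temperature, expectile, action_space) are inferred from the config and the existing IQLLoss API rather than taken from this diff, so treat them as assumptions.

from torchrl.objectives import DiscreteIQLLoss, HardUpdate

# actor, qvalue and value stand in for the TensorDictModule wrappers that
# make_discrete_iql_model returns; the names are placeholders for this sketch.
loss_module = DiscreteIQLLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    value_network=value,
    loss_function="l2",          # cfg.loss.loss_function
    temperature=100.0,           # cfg.loss.temperature
    expectile=0.8,               # cfg.loss.expectile
    action_space="categorical",  # assumption: may also be inferred from the actor spec
)
loss_module.make_value_estimator(gamma=0.99)  # cfg.loss.gamma
target_net_updater = HardUpdate(
    loss_module, value_network_update_interval=10  # cfg.loss.hard_update_interval
)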
195 changes: 195 additions & 0 deletions examples/iql/discrete_iql.py
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IQL Example.

This is a self-contained example of an online discrete IQL training script.

It works across Gym and MuJoCo over a variety of tasks.

The helper functions are coded in the utils.py associated with this script.

"""
import logging
import time

import hydra
import numpy as np
import torch
import tqdm
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
log_metrics,
make_collector,
make_discrete_iql_model,
make_discrete_loss,
make_environment,
make_iql_optimizer,
make_replay_buffer,
)


@hydra.main(config_path=".", config_name="discrete_iql")
def main(cfg: "DictConfig"): # noqa: F821
# Create logger
exp_name = generate_exp_name("Discrete-IQL-online", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="iql_logging",
experiment_name=exp_name,
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)
device = torch.device(cfg.optim.device)

# Create environments
train_env, eval_env = make_environment(
cfg,
cfg.env.train_num_envs,
cfg.env.eval_num_envs,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
device="cpu",
)

# Create model
model = make_discrete_iql_model(cfg, train_env, eval_env, device)

# Create collector
collector = make_collector(cfg, train_env, actor_model_explore=model[0])

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)

# Create optimizer
optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
cfg.optim, loss_module
)

# Main loop
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.collector.max_frames_per_traj
sampling_start = start_time = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
pbar.update(tensordict.numel())
# update weights of the inference policy
collector.update_policy_weights_()

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
for _ in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict
# compute losses
actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

value_loss, _ = loss_module.value_loss(sampled_tensordict)
optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step()

q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# update qnet_target params
target_net_updater.step()

# update priority
if prb:
sampled_tensordict.set(
loss_module.tensor_keys.priority,
metadata.pop("td_error").detach().max(0).values,
)
replay_buffer.update_priority(sampled_tensordict)

training_time = time.time() - training_start
episode_rewards = tensordict["next", "episode_reward"][
tensordict["next", "done"]
]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][
tensordict["next", "done"]
]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = q_loss.detach()
metrics_to_log["train/actor_loss"] = actor_loss.detach()
metrics_to_log["train/value_loss"] = value_loss.detach()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
logging.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
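The three optimization steps in the training loop above follow the standard IQL decomposition: an expectile regression for the value network, an advantage-weighted (temperature-scaled) log-likelihood for the actor, and a TD regression for the Q-network. The snippet below spells out that math for a discrete action space as a plain PyTorch illustration of what loss_module.actor_loss / value_loss / qvalue_loss compute; it is not the internal implementation of DiscreteIQLLoss, and the clamp on the exponential weight is an assumption borrowed from common IQL implementations.

import torch
import torch.nn.functional as F

def discrete_iql_losses(
    q_values,         # (batch, n_actions): Q(s, .) from the online critic
    target_q_values,  # (batch, n_actions): Q(s, .) from the target critic
    v,                # (batch,): V(s)
    v_next,           # (batch,): V(s')
    logits,           # (batch, n_actions): actor logits for pi(.|s)
    action,           # (batch,): index of the action actually taken
    reward,           # (batch,)
    done,             # (batch,): boolean
    gamma=0.99,
    expectile=0.8,
    temperature=100.0,
):
    """Illustrative discrete IQL objectives; a sketch, not the DiscreteIQLLoss code."""
    idx = action.long().unsqueeze(-1)
    q_sa = q_values.gather(-1, idx).squeeze(-1)
    q_sa_target = target_q_values.gather(-1, idx).squeeze(-1)

    # Value loss: expectile regression pushing V(s) towards an upper expectile of Q(s, a).
    diff = q_sa_target.detach() - v
    weight = torch.where(diff > 0, torch.full_like(diff, expectile),
                         torch.full_like(diff, 1.0 - expectile))
    value_loss = (weight * diff.pow(2)).mean()

    # Actor loss: log-likelihood of the taken action, weighted by exp(temperature * advantage).
    adv = (q_sa_target - v).detach()
    exp_adv = torch.exp(temperature * adv).clamp(max=100.0)
    log_prob = torch.log_softmax(logits, dim=-1).gather(-1, idx).squeeze(-1)
    actor_loss = -(exp_adv * log_prob).mean()

    # Q loss: TD regression of Q(s, a) towards r + gamma * V(s').
    td_target = reward + gamma * (~done).float() * v_next.detach()
    q_loss = F.mse_loss(q_sa, td_target)

    return actor_loss, value_loss, q_loss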
58 changes: 58 additions & 0 deletions examples/iql/discrete_iql.yaml
@@ -0,0 +1,58 @@
# task and env
env:
name: CartPole-v1
task: ""
exp_name: iql_${env.name}
n_samples_stats: 1000
seed: 0
train_num_envs: 1
eval_num_envs: 1
backend: gym


# collector
collector:
frames_per_batch: 200
total_frames: 20000
init_random_frames: 1000
env_per_collector: 1
device: cpu
max_frames_per_traj: 200

# logger
logger:
backend: wandb
log_interval: 5000 # record interval in frames
eval_steps: 200
mode: online
eval_iter: 1000

# replay buffer
replay_buffer:
prb: 0
buffer_prefetch: 64
size: 1_000_000

# optimization
optim:
utd_ratio: 1
device: cuda:0
lr: 3e-4
weight_decay: 0.0
batch_size: 256

# network
model:
hidden_sizes: [256, 256]
activation: relu


# loss
loss:
loss_function: l2
gamma: 0.99
hard_update_interval: 10

# IQL specific hyperparameter
temperature: 100
expectile: 0.8
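The replay_buffer block above is consumed by make_replay_buffer in examples/iql/utils.py, which is not shown in this diff. Below is a rough sketch of what the prb flag plausibly toggles, following the pattern used by the other TorchRL example scripts; the exact constructor arguments, including alpha and beta, are assumptions.

from torchrl.data import (
    LazyMemmapStorage,
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)

def make_replay_buffer_sketch(batch_size, prb=False, buffer_size=1_000_000, prefetch=3):
    # Memory-mapped storage keeps the 1M-transition buffer on disk-backed memory maps.
    storage = LazyMemmapStorage(buffer_size)
    if prb:
        # Prioritized variant: priorities are refreshed from the TD error via
        # replay_buffer.update_priority(...) in the training loop above.
        return TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            storage=storage,
            batch_size=batch_size,
            prefetch=prefetch,
        )
    return TensorDictReplayBuffer(
        storage=storage,
        batch_size=batch_size,
        prefetch=prefetch,
    )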
40 changes: 26 additions & 14 deletions examples/iql/iql_offline.py
@@ -34,7 +34,6 @@
@set_gym_backend("gym")
@hydra.main(config_path=".", config_name="offline_config")
def main(cfg: "DictConfig"): # noqa: F821

# Create logger
exp_name = generate_exp_name("IQL-offline", cfg.env.exp_name)
logger = None
@@ -64,7 +63,9 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_module, target_net_updater = make_loss(cfg.loss, model)

# Create optimizer
-    optimizer = make_iql_optimizer(cfg.optim, loss_module)
+    optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer(
+        cfg.optim, loss_module
+    )

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

@@ -78,18 +79,29 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar.update(i)
# sample data
data = replay_buffer.sample()
-        # compute loss
-        loss_vals = loss_module(data.clone().to(device))
-
-        actor_loss = loss_vals["loss_actor"]
-        q_loss = loss_vals["loss_qvalue"]
-        value_loss = loss_vals["loss_value"]
-        loss_val = actor_loss + q_loss + value_loss
-
-        # update model
-        optimizer.zero_grad()
-        loss_val.backward()
-        optimizer.step()
+
+        if data.device != device:
+            data = data.to(device, non_blocking=True)
+
+        # compute losses
+        loss_info = loss_module(data)
+        actor_loss = loss_info["loss_actor"]
+        value_loss = loss_info["loss_value"]
+        q_loss = loss_info["loss_qvalue"]
+
+        optimizer_actor.zero_grad()
+        actor_loss.backward()
+        optimizer_actor.step()
+
+        optimizer_value.zero_grad()
+        value_loss.backward()
+        optimizer_value.step()
+
+        optimizer_critic.zero_grad()
+        q_loss.backward()
+        optimizer_critic.step()

# update qnet_target params
target_net_updater.step()

# log metrics
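The call site above now unpacks three optimizers from make_iql_optimizer (defined in examples/iql/utils.py, not shown in this diff), so that the actor, value, and Q-value objectives are stepped independently, mirroring the new discrete script. Here is a sketch of what such a helper could look like, assuming the loss module exposes its per-network parameters as actor_network_params, qvalue_network_params and value_network_params as other TorchRL losses do; the attribute access pattern is an assumption.

from torch import optim

def make_iql_optimizer_sketch(cfg_optim, loss_module):
    # Split the loss module's parameters by network so each loss gets its own optimizer.
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
    value_params = list(loss_module.value_network_params.flatten_keys().values())

    optimizer_actor = optim.Adam(
        actor_params, lr=cfg_optim.lr, weight_decay=cfg_optim.weight_decay
    )
    optimizer_critic = optim.Adam(
        critic_params, lr=cfg_optim.lr, weight_decay=cfg_optim.weight_decay
    )
    optimizer_value = optim.Adam(
        value_params, lr=cfg_optim.lr, weight_decay=cfg_optim.weight_decay
    )
    return optimizer_actor, optimizer_critic, optimizer_value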