[Feature] Implicit Q-Learning (IQL) #933

Merged · 19 commits · Mar 14, 2023
Changes from 5 commits
327 changes: 327 additions & 0 deletions examples/iql/iql_online.py
@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import hydra

import numpy as np
import torch
import torch.cuda
import tqdm

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer

from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import (
MLP,
NormalParamWrapper,
ProbabilisticActor,
SafeModule,
ValueOperator,
)
from torchrl.modules.distributions import TanhNormal

from torchrl.objectives import SoftUpdate
from torchrl.objectives.iql import IQLLoss
from torchrl.record.loggers import generate_exp_name, get_logger


def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False):
    return GymEnv(
        env_name, device=device, frame_skip=frame_skip, from_pixels=from_pixels
    )


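# Transitions are stored on disk through LazyMemmapStorage so large buffers do
# not need to fit in RAM; `prb` switches between uniform sampling and a
# prioritized (alpha/beta) sampler.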
def make_replay_buffer(
prb=False,
buffer_size=1000000,
buffer_scratch_dir="/tmp/",
device="cpu",
    prefetch=3,
):
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
alpha=0.7,
beta=0.5,
pin_memory=False,
            prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
            prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
device=device,
),
)
return replay_buffer


@hydra.main(version_base=None, config_path=".", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821

device = (
torch.device("cuda:0")
if torch.cuda.is_available()
and torch.cuda.device_count() > 0
and cfg.device == "cuda:0"
else torch.device("cpu")
)

exp_name = generate_exp_name("Online_IQL", cfg.exp_name)
logger = get_logger(
logger_type=cfg.logger, logger_name="iql_logging", experiment_name=exp_name
)

torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)

    def env_factory(num_workers):
        """Creates a ParallelEnv running num_workers copies of the environment."""

# 1.2 Create env vector
vec_env = ParallelEnv(
create_env_fn=EnvCreator(lambda: env_maker(env_name=cfg.env_name)),
num_workers=num_workers,
)

return vec_env

# Sanity check
test_env = env_factory(num_workers=5)
num_actions = test_env.action_spec.shape[-1]

# Create Agent
# Define Actor Network
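    # The MLP outputs 2 * num_actions values that NormalParamWrapper splits into
    # a location and a (softplus-mapped) scale; the ProbabilisticActor then
    # samples through a TanhNormal bounded by the action spec.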
in_keys = ["observation"]
action_spec = test_env.action_spec
actor_net_kwargs = {
"num_cells": [256, 256],
"out_features": 2 * num_actions,
"activation_class": nn.ReLU,
}

actor_net = MLP(**actor_net_kwargs)

dist_class = TanhNormal
dist_kwargs = {
"min": action_spec.space.minimum[-1],
"max": action_spec.space.maximum[-1],
"tanh_loc": cfg.tanh_loc,
}

actor_net = NormalParamWrapper(
actor_net,
scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
scale_lb=cfg.scale_lb,
)
in_keys_actor = in_keys
actor_module = SafeModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"loc",
"scale",
],
)
actor = ProbabilisticActor(
spec=action_spec,
in_keys=["loc", "scale"],
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
return_log_prob=False,
)

# Define Critic Network
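    # IQL trains a state-action critic Q(s, a) alongside a separate state-value
    # network V(s); the Q network therefore takes the action as an extra input.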
qvalue_net_kwargs = {
"num_cells": [256, 256],
"out_features": 1,
"activation_class": nn.ReLU,
}

qvalue_net = MLP(
**qvalue_net_kwargs,
)

qvalue = ValueOperator(
in_keys=["action"] + in_keys,
module=qvalue_net,
)

# Define Value Network
value_net_kwargs = {
"num_cells": [256, 256],
"out_features": 1,
"activation_class": nn.ReLU,
}
value_net = MLP(**value_net_kwargs)
value = ValueOperator(
in_keys=in_keys,
module=value_net,
)

model = nn.ModuleList([actor, qvalue, value]).to(device)

# init nets
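    # A dummy forward pass materializes the lazily-initialized layers before the
    # loss module and target networks are created.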
with torch.no_grad():
td = test_env.reset()
td = td.to(device)
actor(td)
qvalue(td)
value(td)

del td
test_env.close()
test_env.eval()

# Create IQL loss
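    # The value network V(s) is regressed towards Q(s, a) with an expectile loss
    # (expectile > 0.5 biases it towards the upper end of Q), the Q networks are
    # trained with a TD target built from V(s'), and the actor is updated by
    # advantage-weighted regression with weights exp(temperature * (Q - V)),
    # which avoids evaluating Q on out-of-distribution actions.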
loss_module = IQLLoss(
actor_network=model[0],
qvalue_network=model[1],
value_network=model[2],
num_qvalue_nets=2,
gamma=cfg.gamma,
temperature=cfg.temperature,
expectile=cfg.expectile,
loss_function="smooth_l1",
)

# Define Target Network Updater
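    # After each optimization step the target Q parameters are moved towards the
    # online ones: target <- polyak * target + (1 - polyak) * online.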
target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)

# Make Off-Policy Collector
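    # The collector steps cfg.env_per_collector parallel environments with the
    # current policy and yields batches of cfg.frames_per_batch frames.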
collector = SyncDataCollector(
env_factory,
create_env_kwargs={"num_workers": cfg.env_per_collector},
policy=model[0],
frames_per_batch=cfg.frames_per_batch,
max_frames_per_traj=cfg.max_frames_per_traj,
total_frames=cfg.total_frames,
device=cfg.device,
passing_device=cfg.device,
)
collector.set_seed(cfg.seed)

# Make Replay Buffer
    replay_buffer = make_replay_buffer(
        prb=cfg.prb, buffer_size=cfg.buffer_size, device=device
    )

# Optimizers
params = list(loss_module.parameters())
optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

rewards = []
rewards_eval = []

# Main loop
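    # Each iteration collects a batch, extends the replay buffer, runs the
    # gradient updates, and logs training and evaluation rewards.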
target_net_updater.init_()

collected_frames = 0

pbar = tqdm.tqdm(total=cfg.total_frames)
r0 = None
loss = None

for i, tensordict in enumerate(collector):

# update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

if "mask" in tensordict.keys():
# if multi-step, a mask is present to help filter padded values
current_frames = tensordict["mask"].sum()
tensordict = tensordict[tensordict.get("mask").squeeze(-1)]
else:
tensordict = tensordict.view(-1)
current_frames = tensordict.numel()
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

(
actor_losses,
q_losses,
value_losses,
) = ([], [], [])
# optimization steps
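        # Run cfg.frames_per_batch * cfg.utd_ratio gradient steps per collected
        # batch, i.e. utd_ratio updates per collected frame.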
for _ in range(cfg.frames_per_batch * int(cfg.utd_ratio)):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample(cfg.batch_size).clone()

loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
value_loss = loss_td["loss_value"]

loss = actor_loss + q_loss + value_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())
value_losses.append(value_loss.item())

# update qnet_target params
target_net_updater.step()

# update priority
if cfg.prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
"value_loss": np.mean(value_losses),
}
)
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

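        # Evaluate the deterministic ("mean") policy without exploration noise
        # or gradient tracking.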
with set_exploration_mode("mean"), torch.no_grad():
eval_rollout = test_env.rollout(
max_steps=cfg.max_frames_per_traj,
policy=model[0],
auto_cast_to_device=True,
).clone()
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)

collector.shutdown()


if __name__ == "__main__":
main()
50 changes: 50 additions & 0 deletions examples/iql/online_config.yaml
@@ -0,0 +1,50 @@
env_name: Pendulum-v1
env_library: gym
exp_name: "iql_pendulum"
seed: 42
async_collection: 1
record_video: 0
frame_skip: 1

total_frames: 1000000
init_env_steps: 10000
init_random_frames: 5000
# Updates
utd_ratio: 1.0
batch_size: 256
lr: 3e-4
weight_decay: 0.0
target_update_polyak: 0.995
multi_step: 1.0
gamma: 0.99

tanh_loc: False
default_policy_scale: 1.0
scale_lb: 0.1
activation: elu
from_pixels: 0
#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
collector_devices: [cpu]
env_per_collector: 5
frames_per_batch: 1000 # 5*200
max_frames_per_traj: 200
num_workers: 1

record_frames: 10000
loss_function: smooth_l1
batch_transform: 1
buffer_prefetch: 64
norm_stats: 1

device: "cuda:0"

# IQL hyperparameters
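# temperature scales advantages in the actor's advantage-weighted update;
# expectile controls the asymmetry of the value regression (0.5 = mean,
# values approaching 1 approximate a max over actions).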
temperature: 3.0
expectile: 0.7

# Logging
logger: wandb

# Replay Buffer
prb: 0
buffer_size: 100000