[Algorithm] GAIL (pytorch#2273)
Co-authored-by: Vincent Moens <vmoens@meta.com>
BY571 and vmoens authored Aug 6, 2024
1 parent afe8596 commit 4348c84
Showing 9 changed files with 1,036 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
env.train_num_envs=2 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
ppo.collector.total_frames=48 \
replay_buffer.batch_size=16 \
ppo.loss.mini_batch_size=10 \
ppo.collector.frames_per_batch=16 \
logger.mode=offline \
logger.backend=
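# Note: the overrides above are smoke-test values; 48 total frames collected in
# batches of 16 give three tiny collector iterations, just enough to exercise
# the GAIL example end-to-end in CI.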

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
@@ -179,6 +179,15 @@ CQL
    CQLLoss
    DiscreteCQLLoss

GAIL
----

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    GAILLoss

DT
----

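A minimal sketch of how the new loss is driven, mirroring the discriminator update in gail.py further down. The discriminator, the expert minibatch and the policy tensors are placeholders here; in the script they come from make_gail_discriminator, the d4rl replay buffer and the PPO collector:

import torch
from torchrl.objectives import GAILLoss

# Placeholders: `discriminator` is a TensorDictModule that reads "observation"/"action"
# and writes "d_logits"; `expert_batch` is a TensorDict of expert observations/actions.
loss_module = GAILLoss(
    discriminator,
    use_grad_penalty=False,  # cfg.gail.use_grad_penalty
    gp_lambda=10.0,          # cfg.gail.gp_lambda
)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=3e-4)

# Pair the expert minibatch with an equally sized slice of policy data, stored
# under the keys the loss advertises.
expert_batch.set(loss_module.tensor_keys.collector_observation, policy_obs)
expert_batch.set(loss_module.tensor_keys.collector_action, policy_act)

loss_td = loss_module(expert_batch)  # tensordict with a "loss" entry
disc_optim.zero_grad()
loss_td.get("loss").backward()
disc_optim.step()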
46 changes: 46 additions & 0 deletions sota-implementations/gail/config.yaml
@@ -0,0 +1,46 @@
env:
  env_name: HalfCheetah-v4
  seed: 42
  backend: gymnasium

logger:
  backend: wandb
  project_name: gail
  group_name: null
  exp_name: gail_ppo
  test_interval: 5000
  num_test_episodes: 5
  video: False
  mode: online

ppo:
  collector:
    frames_per_batch: 2048
    total_frames: 1_000_000

  optim:
    lr: 3e-4
    weight_decay: 0.0
    anneal_lr: True

  loss:
    gamma: 0.99
    mini_batch_size: 64
    ppo_epochs: 10
    gae_lambda: 0.95
    clip_epsilon: 0.2
    anneal_clip_epsilon: False
    critic_coef: 0.25
    entropy_coef: 0.0
    loss_critic_type: l2

gail:
  hidden_dim: 128
  lr: 3e-4
  use_grad_penalty: False
  gp_lambda: 10.0
  device: null

replay_buffer:
  dataset: halfcheetah-expert-v2
  batch_size: 256
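This file is consumed by gail.py through @hydra.main below, and any key can be overridden on the command line, as the CI entry above does with ppo.collector.total_frames=48. For a quick look at the resolved defaults outside Hydra, a sketch along these lines should work (assuming omegaconf, which Hydra depends on, is installed):

from omegaconf import OmegaConf

cfg = OmegaConf.load("sota-implementations/gail/config.yaml")
print(cfg.ppo.collector.frames_per_batch)  # 2048
print(cfg.ppo.loss.mini_batch_size)        # 64
print(cfg.gail.device)                     # None (resolved to cuda:0 or cpu in gail.py)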
281 changes: 281 additions & 0 deletions sota-implementations/gail/gail.py
@@ -0,0 +1,281 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""GAIL Example.
This is a self-contained example of an offline GAIL training script.
The helper functions for gail are coded in the gail_utils.py and helper functions for ppo in ppo_utils.
"""
import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()

device = cfg.gail.device
if device in ("", None):
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
device = torch.device(device)
num_mini_batches = (
cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
)
total_network_updates = (
(cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
* cfg.ppo.loss.ppo_epochs
* num_mini_batches
)
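# With the defaults in config.yaml: num_mini_batches = 2048 // 64 = 32 and
# total_network_updates = (1_000_000 // 2048) * 10 * 32 = 488 * 320 = 156_160.
# This total is only used further down to anneal the learning rate and clip epsilon.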

# Create logger
exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
logger_type=cfg.logger.backend,
logger_name="gail_logging",
experiment_name=exp_name,
wandb_kwargs={
"mode": cfg.logger.mode,
"config": dict(cfg),
"project": cfg.logger.project_name,
"group": cfg.logger.group_name,
},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create models (check ppo_utils.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.ppo.loss.mini_batch_size,
)

# Create loss and adv modules
adv_module = GAE(
gamma=cfg.ppo.loss.gamma,
lmbda=cfg.ppo.loss.gae_lambda,
value_network=critic,
average_gae=False,
)

loss_module = ClipPPOLoss(
actor_network=actor,
critic_network=critic,
clip_epsilon=cfg.ppo.loss.clip_epsilon,
loss_critic_type=cfg.ppo.loss.loss_critic_type,
entropy_coef=cfg.ppo.loss.entropy_coef,
critic_coef=cfg.ppo.loss.critic_coef,
normalize_advantage=True,
)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

# Create Discriminator
discriminator = make_gail_discriminator(cfg, collector.env, device)

# Create loss
discriminator_loss = GAILLoss(
discriminator,
use_grad_penalty=cfg.gail.use_grad_penalty,
gp_lambda=cfg.gail.gp_lambda,
)

# Create optimizer
discriminator_optim = torch.optim.Adam(
params=discriminator.parameters(), lr=cfg.gail.lr
)

# Create test environment
logger_video = cfg.logger.video
test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
if logger_video:
test_env = test_env.append_transform(
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()

# Training loop
collected_frames = 0
num_network_updates = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
data["action"][: expert_data.batch_size[0]],
)
expert_data.set(
discriminator_loss.tensor_keys.collector_observation,
data["observation"][: expert_data.batch_size[0]],
)
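# Note: the collector batch is truncated to the expert batch size (256 by default),
# so the discriminator always sees equally sized expert and policy minibatches.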
d_loss = discriminator_loss(expert_data)

# Backward pass
discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()

# Compute discriminator reward
with torch.no_grad():
data = discriminator(data)
d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)
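# r = -log(1 - D(s, a)): a transition is rewarded for looking expert-like to the
# discriminator; the 1e-8 keeps the log finite when d_logits saturates at 1.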

# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.extend(data_reshape)

for _, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in actor_optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
for group in critic_optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
num_network_updates += 1

# Forward pass PPO loss
loss = loss_module(batch)
critic_loss = loss["loss_critic"]
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
if cfg_loss_anneal_clip_eps
else cfg_loss_clip_epsilon
),
}
)

# evaluation
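# Run evaluation whenever the cumulative frame count crosses a multiple of
# cfg.logger.test_interval (5000 by default).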
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg_logger_test_interval:
actor.eval()
test_rewards = eval_model(
actor, test_env, num_episodes=cfg_logger_num_test_episodes
)
log_info.update(
{
"eval/reward": test_rewards.mean(),
}
)
actor.train()
if logger is not None:
log_metrics(logger, log_info, i)

pbar.close()


if __name__ == "__main__":
main()