[Algorithm] GAIL #2273

Merged (35 commits, Aug 6, 2024)

Commits
6d7c5c4  fix norm (BY571, Jul 4, 2024)
2d31f33  update docs (BY571, Jul 5, 2024)
79bda13  update comments (BY571, Jul 5, 2024)
1391b50  add sota-example-test (BY571, Jul 5, 2024)
f444b72  update collection data slice (BY571, Jul 8, 2024)
244b7ab  update docstring (BY571, Jul 8, 2024)
db635d7  update config and objective with gp param (BY571, Jul 9, 2024)
434622c  init cost tests gail (BY571, Jul 9, 2024)
baca70f  update cost test (BY571, Jul 9, 2024)
956567f  Merge branch 'main' into gail (BY571, Jul 11, 2024)
8e7713f  add gail cost tests (BY571, Jul 11, 2024)
714c35c  Merge remote-tracking branch 'origin/main' into gail (vmoens, Jul 30, 2024)
a05bef3  Merge branch 'main' into gail (BY571, Jul 31, 2024)
b31da8a  Update config (BY571, Jul 31, 2024)
63885b0  update gail device (BY571, Jul 31, 2024)
739332c  update example tests (BY571, Jul 31, 2024)
7a9919c  Merge branch 'gail' of github.com:BY571/rl into gail (BY571, Jul 31, 2024)
3fd3c32  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 2, 2024)
9455fef  gymnasium backend (BY571, Aug 5, 2024)
fba43d2  Merge branch 'gail' of https://github.com/BY571/rl into gail (vmoens, Aug 5, 2024)
5e41d89  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 5, 2024)
6c3f7d2  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 5, 2024)
415443b  Merge remote-tracking branch 'origin/main' into gail (vmoens, Aug 6, 2024)
4926d80  fixes (vmoens, Aug 6, 2024)
cbd5dfa  init (vmoens, Aug 6, 2024)
70e1f49  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
b8ca705  amend (vmoens, Aug 6, 2024)
6a00bda  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
f0c225f  amend (vmoens, Aug 6, 2024)
511fa95  amend (vmoens, Aug 6, 2024)
4bc316b  amend (vmoens, Aug 6, 2024)
2f7e64c  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
3d43e42  amend (vmoens, Aug 6, 2024)
63398d1  Merge branch 'pin-mujoco' into gail (vmoens, Aug 6, 2024)
c488bcd  amend (vmoens, Aug 6, 2024)
6 changes: 6 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -205,6 +205,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
  env.train_num_envs=2 \
  logger.mode=offline \
  logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
  collector.total_frames=48 \
  loss.mini_batch_size=10 \
  collector.frames_per_batch=16 \
  logger.mode=offline \
  logger.backend=

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
@@ -150,6 +150,15 @@ CQL
    CQLLoss
    DiscreteCQLLoss

GAIL
----

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    GAILLoss

DT
----

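As a quick illustration of the documented class, here is a minimal sketch of how the new loss is built, mirroring its use in sota-implementations/gail/gail.py further down. The `discriminator` object is a placeholder for the TensorDictModule returned by make_gail_discriminator, which writes a "d_logits" entry; `mixed_batch` is a hypothetical tensordict holding expert and collector transitions.

from torchrl.objectives import GAILLoss

# Discriminator loss over expert and collector data (sketch, not part of the diff)
discriminator_loss = GAILLoss(
    discriminator,
    use_grad_penalty=False,  # optionally add a gradient penalty term
    gp_lambda=10.0,          # penalty weight, only used when the penalty is enabled
)
loss_td = discriminator_loss(mixed_batch)
loss_td.get("loss").backward()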
50 changes: 50 additions & 0 deletions sota-implementations/gail/config.yaml
@@ -0,0 +1,50 @@
env:
  env_name: HalfCheetah-v4
  seed: 42
  backend: gym


# logger
logger:
  backend: wandb
  project_name: gail
  group_name: null
  exp_name: gail_ppo
  test_interval: 5000
  num_test_episodes: 5
  video: False
  mode: online

ppo:
  # collector
  collector:
    frames_per_batch: 2048
    total_frames: 1_000_000

  # Optim
  optim:
    lr: 3e-4
    weight_decay: 0.0
    anneal_lr: True

  # loss
  loss:
    gamma: 0.99
    mini_batch_size: 64
    ppo_epochs: 10
    gae_lambda: 0.95
    clip_epsilon: 0.2
    anneal_clip_epsilon: False
    critic_coef: 0.25
    entropy_coef: 0.0
    loss_critic_type: l2

gail:
  hidden_dim: 128
  lr: 3e-4
  use_grad_penalty: False
  gp_lambda: 10.0

replay_buffer:
  dataset: halfcheetah-expert-v2
  batch_size: 256
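For orientation (not part of the diff), the PPO update schedule that gail.py derives from this config works out as follows with the values above; this is plain arithmetic on the YAML entries:

frames_per_batch = 2048
mini_batch_size = 64
ppo_epochs = 10
total_frames = 1_000_000

num_mini_batches = frames_per_batch // mini_batch_size                    # 32
collection_iters = total_frames // frames_per_batch                       # 488
total_network_updates = collection_iters * ppo_epochs * num_mini_batches  # 156_160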
275 changes: 275 additions & 0 deletions sota-implementations/gail/gail.py
@@ -0,0 +1,275 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""GAIL Example.

This is a self-contained example of an offline GAIL training script.

The helper functions for gail are coded in the gail_utils.py and helper functions for ppo in ppo_utils.

"""
import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"):  # noqa: F821
    set_gym_backend(cfg.env.backend).set()

    device = "cpu" if not torch.cuda.device_count() else "cuda"
    num_mini_batches = (
        cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
    )
    total_network_updates = (
        (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
        * cfg.ppo.loss.ppo_epochs
        * num_mini_batches
    )

    # Create logger
    exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="gail_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create models (check ppo_utils.py)
    actor, critic = make_ppo_models(cfg.env.env_name)
    actor, critic = actor.to(device), critic.to(device)

    # Create collector
    collector = SyncDataCollector(
        create_env_fn=make_env(cfg.env.env_name, device),
        policy=actor,
        frames_per_batch=cfg.ppo.collector.frames_per_batch,
        total_frames=cfg.ppo.collector.total_frames,
        device=device,
        storing_device=device,
        max_frames_per_traj=-1,
    )

    # Create data buffer
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=cfg.ppo.loss.mini_batch_size,
    )

    # Create loss and adv modules
    adv_module = GAE(
        gamma=cfg.ppo.loss.gamma,
        lmbda=cfg.ppo.loss.gae_lambda,
        value_network=critic,
        average_gae=False,
    )

    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        clip_epsilon=cfg.ppo.loss.clip_epsilon,
        loss_critic_type=cfg.ppo.loss.loss_critic_type,
        entropy_coef=cfg.ppo.loss.entropy_coef,
        critic_coef=cfg.ppo.loss.critic_coef,
        normalize_advantage=True,
    )

    # Create optimizers
    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)

    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

    # Create Discriminator
    discriminator = make_gail_discriminator(cfg, collector.env, device)

    # Create loss
    discriminator_loss = GAILLoss(
        discriminator,
        use_grad_penalty=cfg.gail.use_grad_penalty,
        gp_lambda=cfg.gail.gp_lambda,
    )

    # Create optimizer
    discriminator_optim = torch.optim.Adam(
        params=discriminator.parameters(), lr=cfg.gail.lr
    )

    # Create test environment
    logger_video = cfg.logger.video
    test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
    if logger_video:
        test_env = test_env.append_transform(
            VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
        )
    test_env.eval()

    # Training loop
    collected_frames = 0
    num_network_updates = 0
    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

    # Extract cfg variables
    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
    cfg_optim_lr = cfg.ppo.optim.lr
    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
    cfg_logger_test_interval = cfg.logger.test_interval
    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

    for i, data in enumerate(collector):

        log_info = {}
        frames_in_batch = data.numel()
        collected_frames += frames_in_batch
        pbar.update(data.numel())

        # Update discriminator
        # Get expert data
        expert_data = replay_buffer.sample()
        expert_data = expert_data.to(device)
        # Add collector data to expert data
        expert_data.set(
            discriminator_loss.tensor_keys.collector_action,
            data["action"][: expert_data.batch_size[0]],
        )
        expert_data.set(
            discriminator_loss.tensor_keys.collector_observation,
            data["observation"][: expert_data.batch_size[0]],
        )
        d_loss = discriminator_loss(expert_data)

        # Backward pass
        discriminator_optim.zero_grad()
        d_loss.get("loss").backward()
        discriminator_optim.step()

        # Compute discriminator reward
        with torch.no_grad():
            data = discriminator(data)
        d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)

        # Set discriminator rewards to tensordict
        data.set(("next", "reward"), d_rewards)

        # Get training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
        if len(episode_rewards) > 0:
            episode_length = data["next", "step_count"][data["next", "done"]]
            log_info.update(
                {
                    "train/reward": episode_rewards.mean().item(),
                    "train/episode_length": episode_length.sum().item()
                    / len(episode_length),
                }
            )
        # Update PPO
        for _ in range(cfg_loss_ppo_epochs):

            # Compute GAE
            with torch.no_grad():
                data = adv_module(data)
            data_reshape = data.reshape(-1)

            # Update the data buffer
            data_buffer.extend(data_reshape)

            for _, batch in enumerate(data_buffer):

                # Get a data batch
                batch = batch.to(device)

                # Linearly decrease the learning rate and clip epsilon
                alpha = 1.0
                if cfg_optim_anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in actor_optim.param_groups:
                        group["lr"] = cfg_optim_lr * alpha
                    for group in critic_optim.param_groups:
                        group["lr"] = cfg_optim_lr * alpha
                if cfg_loss_anneal_clip_eps:
                    loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
                num_network_updates += 1

                # Forward pass PPO loss
                loss = loss_module(batch)
                critic_loss = loss["loss_critic"]
                actor_loss = loss["loss_objective"] + loss["loss_entropy"]

                # Backward pass
                actor_loss.backward()
                critic_loss.backward()

                # Update the networks
                actor_optim.step()
                critic_optim.step()
                actor_optim.zero_grad()
                critic_optim.zero_grad()

        log_info.update(
            {
                "train/actor_loss": actor_loss.item(),
                "train/critic_loss": critic_loss.item(),
                "train/discriminator_loss": d_loss["loss"].item(),
                "train/lr": alpha * cfg_optim_lr,
                "train/clip_epsilon": (
                    alpha * cfg_loss_clip_epsilon
                    if cfg_loss_anneal_clip_eps
                    else cfg_loss_clip_epsilon
                ),
            }
        )

        # Evaluation
        with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
            if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
                i * frames_in_batch
            ) // cfg_logger_test_interval:
                actor.eval()
                test_rewards = eval_model(
                    actor, test_env, num_episodes=cfg_logger_num_test_episodes
                )
                log_info.update(
                    {
                        "eval/reward": test_rewards.mean(),
                    }
                )
                actor.train()
        if logger is not None:
            log_metrics(logger, log_info, i)

    pbar.close()


if __name__ == "__main__":
    main()
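A small side note on the surrogate reward used in the loop above, r = -log(1 - D + 1e-8): assuming "d_logits" holds the discriminator's probability that a transition came from the expert data, the reward grows as policy transitions look more expert-like. A tiny standalone check (sketch, not part of the PR):

import torch

d_probs = torch.tensor([0.1, 0.5, 0.9])   # hypothetical discriminator outputs
rewards = -torch.log(1 - d_probs + 1e-8)  # ~ [0.105, 0.693, 2.303]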