From 4348c84b1d6bec2ef553bdb0ce816a45ba912d93 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 7 Aug 2024 00:29:48 +0200 Subject: [PATCH] [Algorithm] GAIL (#2273) Co-authored-by: Vincent Moens --- .../linux_examples/scripts/run_test.sh | 7 + docs/source/reference/objectives.rst | 9 + sota-implementations/gail/config.yaml | 46 +++ sota-implementations/gail/gail.py | 281 ++++++++++++++++++ sota-implementations/gail/gail_utils.py | 69 +++++ sota-implementations/gail/ppo_utils.py | 150 ++++++++++ test/test_cost.py | 222 ++++++++++++++ torchrl/objectives/__init__.py | 1 + torchrl/objectives/gail.py | 251 ++++++++++++++++ 9 files changed, 1036 insertions(+) create mode 100644 sota-implementations/gail/config.yaml create mode 100644 sota-implementations/gail/gail.py create mode 100644 sota-implementations/gail/gail_utils.py create mode 100644 sota-implementations/gail/ppo_utils.py create mode 100644 torchrl/objectives/gail.py diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index f8b700c0410..ef0d081f8fd 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq env.train_num_envs=2 \ logger.mode=offline \ logger.backend= + python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \ + ppo.collector.total_frames=48 \ + replay_buffer.batch_size=16 \ + ppo.loss.mini_batch_size=10 \ + ppo.collector.frames_per_batch=16 \ + logger.mode=offline \ + logger.backend= # With single envs python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 1d92c390a4e..db0c58409e2 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -179,6 +179,15 @@ CQL 
CQLLoss DiscreteCQLLoss +GAIL +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GAILLoss + DT ---- diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml new file mode 100644 index 00000000000..cf6c8053037 --- /dev/null +++ b/sota-implementations/gail/config.yaml @@ -0,0 +1,46 @@ +env: + env_name: HalfCheetah-v4 + seed: 42 + backend: gymnasium + +logger: + backend: wandb + project_name: gail + group_name: null + exp_name: gail_ppo + test_interval: 5000 + num_test_episodes: 5 + video: False + mode: online + +ppo: + collector: + frames_per_batch: 2048 + total_frames: 1_000_000 + + optim: + lr: 3e-4 + weight_decay: 0.0 + anneal_lr: True + + loss: + gamma: 0.99 + mini_batch_size: 64 + ppo_epochs: 10 + gae_lambda: 0.95 + clip_epsilon: 0.2 + anneal_clip_epsilon: False + critic_coef: 0.25 + entropy_coef: 0.0 + loss_critic_type: l2 + +gail: + hidden_dim: 128 + lr: 3e-4 + use_grad_penalty: False + gp_lambda: 10.0 + device: null + +replay_buffer: + dataset: halfcheetah-expert-v2 + batch_size: 256 diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py new file mode 100644 index 00000000000..a3c64693fb3 --- /dev/null +++ b/sota-implementations/gail/gail.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""GAIL Example. + +This is a self-contained example of an offline GAIL training script. + +The helper functions for gail are coded in the gail_utils.py and helper functions for ppo in ppo_utils. 
import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"):  # noqa: F821
    """Train a PPO policy against a GAIL discriminator.

    For every batch yielded by the collector the script:

    1. samples expert transitions from the offline replay buffer and updates
       the discriminator with ``GAILLoss``;
    2. overwrites the collected ``("next", "reward")`` entry with
       ``-log(1 - D(s, a) + 1e-8)``, where ``D`` is the discriminator output;
    3. runs ``cfg.ppo.loss.ppo_epochs`` epochs of clipped PPO on that batch;
    4. periodically evaluates the actor and logs the metrics.

    Args:
        cfg: the hydra configuration (see ``config.yaml`` next to this script).
    """
    set_gym_backend(cfg.env.backend).set()

    # Resolve the device: an empty / null entry means "CUDA if available".
    device = cfg.gail.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
    device = torch.device(device)

    # Total number of optimizer steps, used for the linear annealing schedule.
    num_mini_batches = (
        cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
    )
    total_network_updates = (
        (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
        * cfg.ppo.loss.ppo_epochs
        * num_mini_batches
    )

    # Create logger
    exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="gail_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create models (see ppo_utils.py)
    actor, critic = make_ppo_models(cfg.env.env_name)
    actor, critic = actor.to(device), critic.to(device)

    # Create collector
    collector = SyncDataCollector(
        create_env_fn=make_env(cfg.env.env_name, device),
        policy=actor,
        frames_per_batch=cfg.ppo.collector.frames_per_batch,
        total_frames=cfg.ppo.collector.total_frames,
        device=device,
        storing_device=device,
        max_frames_per_traj=-1,
    )

    # Create data buffer (holds one collector batch, re-filled every PPO epoch)
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=cfg.ppo.loss.mini_batch_size,
    )

    # Create loss and adv modules
    adv_module = GAE(
        gamma=cfg.ppo.loss.gamma,
        lmbda=cfg.ppo.loss.gae_lambda,
        value_network=critic,
        average_gae=False,
    )

    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        clip_epsilon=cfg.ppo.loss.clip_epsilon,
        loss_critic_type=cfg.ppo.loss.loss_critic_type,
        entropy_coef=cfg.ppo.loss.entropy_coef,
        critic_coef=cfg.ppo.loss.critic_coef,
        normalize_advantage=True,
    )

    # Create optimizers
    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)

    # Create replay buffer (expert demonstrations)
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

    # Create Discriminator
    discriminator = make_gail_discriminator(cfg, collector.env, device)

    # Create loss
    discriminator_loss = GAILLoss(
        discriminator,
        use_grad_penalty=cfg.gail.use_grad_penalty,
        gp_lambda=cfg.gail.gp_lambda,
    )

    # Create optimizer
    discriminator_optim = torch.optim.Adam(
        params=discriminator.parameters(), lr=cfg.gail.lr
    )

    # Create test environment
    logger_video = cfg.logger.video
    test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
    if logger_video:
        test_env = test_env.append_transform(
            VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
        )
    test_env.eval()

    # Training loop
    collected_frames = 0
    num_network_updates = 0
    pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

    # extract cfg variables
    cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
    cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
    cfg_optim_lr = cfg.ppo.optim.lr
    cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
    cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
    cfg_logger_test_interval = cfg.logger.test_interval
    cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

    for i, data in enumerate(collector):

        log_info = {}
        frames_in_batch = data.numel()
        collected_frames += frames_in_batch
        pbar.update(frames_in_batch)

        # Update discriminator
        # Get expert data
        expert_data = replay_buffer.sample()
        expert_data = expert_data.to(device)

        # The loss pairs expert and collector samples one-to-one, so both sides
        # must have the same leading size. Truncate whichever is larger instead
        # of failing when the collector batch is smaller than the expert batch.
        num_pairs = min(expert_data.batch_size[0], data["action"].shape[0])
        if num_pairs < expert_data.batch_size[0]:
            expert_data = expert_data[:num_pairs]

        # Add collector data to expert data
        expert_data.set(
            discriminator_loss.tensor_keys.collector_action,
            data["action"][:num_pairs],
        )
        expert_data.set(
            discriminator_loss.tensor_keys.collector_observation,
            data["observation"][:num_pairs],
        )
        d_loss = discriminator_loss(expert_data)

        # Backward pass
        discriminator_optim.zero_grad()
        d_loss.get("loss").backward()
        discriminator_optim.step()

        # Compute discriminator reward
        with torch.no_grad():
            data = discriminator(data)
        # 1e-8 keeps the log finite when the discriminator output reaches 1.
        d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)

        # Set discriminator rewards to tensordict
        data.set(("next", "reward"), d_rewards)

        # Get training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
        if len(episode_rewards) > 0:
            episode_length = data["next", "step_count"][data["next", "done"]]
            log_info.update(
                {
                    "train/reward": episode_rewards.mean().item(),
                    "train/episode_length": episode_length.sum().item()
                    / len(episode_length),
                }
            )

        # Update PPO
        for _ in range(cfg_loss_ppo_epochs):

            # Compute GAE
            with torch.no_grad():
                data = adv_module(data)
            data_reshape = data.reshape(-1)

            # Update the data buffer
            data_buffer.extend(data_reshape)

            for _, batch in enumerate(data_buffer):

                # Get a data batch
                batch = batch.to(device)

                # Linearly decrease the learning rate and clip epsilon
                alpha = 1.0
                if cfg_optim_anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in actor_optim.param_groups:
                        group["lr"] = cfg_optim_lr * alpha
                    for group in critic_optim.param_groups:
                        group["lr"] = cfg_optim_lr * alpha
                if cfg_loss_anneal_clip_eps:
                    loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
                num_network_updates += 1

                # Forward pass PPO loss
                loss = loss_module(batch)
                critic_loss = loss["loss_critic"]
                actor_loss = loss["loss_objective"] + loss["loss_entropy"]

                # Backward pass
                actor_loss.backward()
                critic_loss.backward()

                # Update the networks
                actor_optim.step()
                critic_optim.step()
                actor_optim.zero_grad()
                critic_optim.zero_grad()

        # Log the values of the last mini-batch update of this iteration.
        log_info.update(
            {
                "train/actor_loss": actor_loss.item(),
                "train/critic_loss": critic_loss.item(),
                "train/discriminator_loss": d_loss["loss"].item(),
                "train/lr": alpha * cfg_optim_lr,
                "train/clip_epsilon": (
                    alpha * cfg_loss_clip_epsilon
                    if cfg_loss_anneal_clip_eps
                    else cfg_loss_clip_epsilon
                ),
            }
        )

        # evaluation: runs whenever the frame count crosses a multiple of the
        # test interval (the floor-division comparison is also true for i == 0).
        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
            if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
                i * frames_in_batch
            ) // cfg_logger_test_interval:
                actor.eval()
                test_rewards = eval_model(
                    actor, test_env, num_episodes=cfg_logger_num_test_episodes
                )
                log_info.update(
                    {
                        "eval/reward": test_rewards.mean(),
                    }
                )
                actor.train()
        if logger is not None:
            log_metrics(logger, log_info, i)

    pbar.close()


if __name__ == "__main__":
    main()
import torch
import torch.nn as nn
import torch.optim

from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.envs import DoubleToFloat

from torchrl.modules import SafeModule


# ====================================================================
# Offline Replay buffer
# ---------------------------


def make_offline_replay_buffer(rb_cfg):
    """Build the expert replay buffer from a D4RL dataset.

    Args:
        rb_cfg: config node providing ``dataset`` (D4RL dataset id) and
            ``batch_size`` (number of transitions returned by ``sample()``).

    Returns:
        The ``D4RLExperienceReplay`` buffer with a ``DoubleToFloat`` transform
        appended.
    """
    data = D4RLExperienceReplay(
        dataset_id=rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        sampler=SamplerWithoutReplacement(drop_last=False),
        prefetch=4,
        direct_download=True,
    )

    data.append_transform(DoubleToFloat())

    return data


# ====================================================================
# Discriminator
# ---------------------------


class _Discriminator(nn.Module):
    """Three-layer MLP mapping a (state, action) pair to a value in (0, 1)."""

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        # Concatenate on the feature (last) dimension so that any number of
        # leading batch dimensions is supported, e.g. (batch, time, feat).
        x = torch.cat([state, action], dim=-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))


def make_gail_discriminator(cfg, train_env, device="cpu"):
    """Make GAIL discriminator.

    Args:
        cfg: config providing ``gail.hidden_dim``.
        train_env: environment whose ``observation_spec["observation"]`` and
            ``action_spec`` give the input feature sizes.
        device: device the returned module is moved to.

    Returns:
        A ``SafeModule`` reading ``"observation"`` and ``"action"`` and writing
        ``"d_logits"``. Despite the key name, the network ends with a sigmoid,
        so the written values lie in (0, 1).
    """
    # The feature size is the last spec dimension.
    state_dim = train_env.observation_spec["observation"].shape[-1]
    action_dim = train_env.action_spec.shape[-1]

    d_module = SafeModule(
        module=_Discriminator(state_dim, action_dim, cfg.gail.hidden_dim),
        in_keys=["observation", "action"],
        out_keys=["d_logits"],
    )
    return d_module.to(device)


def log_metrics(logger, metrics, step):
    """Log every ``name -> value`` pair of ``metrics`` as a scalar at ``step``.

    Does nothing when ``logger`` is ``None``.
    """
    if logger is not None:
        for metric_name, metric_value in metrics.items():
            logger.log_scalar(metric_name, metric_value, step)
import torch.nn
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.envs import (
    ClipTransform,
    DoubleToFloat,
    ExplorationType,
    RewardSum,
    StepCounter,
    TransformedEnv,
    VecNorm,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.record import VideoRecorder


# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
    """Create the transformed gym environment used for training and testing.

    The transforms are appended in this order: observation normalization
    (``VecNorm``), observation clipping to ``[-10, 10]``, episode reward sum,
    step counter, and double-to-float conversion of the observation.
    """
    base_env = GymEnv(
        env_name, device=device, from_pixels=from_pixels, pixels_only=False
    )
    wrapped_env = TransformedEnv(base_env)
    pipeline = (
        VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2),
        ClipTransform(in_keys=["observation"], low=-10, high=10),
        RewardSum(),
        StepCounter(),
        DoubleToFloat(in_keys=["observation"]),
    )
    for transform in pipeline:
        wrapped_env.append_transform(transform)
    return wrapped_env


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def _orthogonal_init(network, gain):
    """Orthogonally initialize every linear layer of ``network``; zero its biases."""
    for submodule in network.modules():
        if isinstance(submodule, torch.nn.Linear):
            torch.nn.init.orthogonal_(submodule.weight, gain)
            submodule.bias.data.zero_()


def make_ppo_models_state(proof_environment):
    """Build the PPO actor and critic for a state-based environment.

    Args:
        proof_environment: an environment instance, used only to read the
            observation and action specs.

    Returns:
        A ``(policy_module, value_module)`` tuple.
    """
    # Feature sizes come from the last dimension of the specs.
    obs_features = proof_environment.observation_spec["observation"].shape[-1]
    action_features = proof_environment.action_spec.shape[-1]

    # Actor trunk: predicts only the location of the action distribution.
    actor_net = MLP(
        in_features=obs_features,
        activation_class=torch.nn.Tanh,
        out_features=action_features,
        num_cells=[64, 64],
    )
    _orthogonal_init(actor_net, 1.0)

    # The scale is a learned, state-independent parameter appended to the trunk.
    actor_net = torch.nn.Sequential(
        actor_net,
        AddStateIndependentNormalScale(
            proof_environment.action_spec.shape[-1], scale_lb=1e-8
        ),
    )

    # Wrap the trunk so that actions are sampled from a TanhNormal bounded by
    # the action spec limits.
    policy_module = ProbabilisticActor(
        TensorDictModule(
            module=actor_net,
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        spec=CompositeSpec(action=proof_environment.action_spec),
        distribution_class=TanhNormal,
        distribution_kwargs={
            "low": proof_environment.action_spec.space.low,
            "high": proof_environment.action_spec.space.high,
            "tanh_loc": False,
        },
        return_log_prob=True,
        default_interaction_type=ExplorationType.RANDOM,
    )

    # Critic: same architecture with a single output and a small init gain.
    critic_net = MLP(
        in_features=obs_features,
        activation_class=torch.nn.Tanh,
        out_features=1,
        num_cells=[64, 64],
    )
    _orthogonal_init(critic_net, 0.01)

    value_module = ValueOperator(
        critic_net,
        in_keys=["observation"],
    )

    return policy_module, value_module


def make_ppo_models(env_name):
    """Create the PPO actor and critic for ``env_name`` using a CPU proof env."""
    return make_ppo_models_state(make_env(env_name, device="cpu"))


# ====================================================================
# Evaluation utils
# --------------------------------------------------------------------


def dump_video(module):
    """Call ``dump()`` on ``module`` when it is a ``VideoRecorder``."""
    if not isinstance(module, VideoRecorder):
        return
    module.dump()
def eval_model(actor, test_env, num_episodes=3):
    """Roll out ``actor`` for ``num_episodes`` episodes and return the mean reward.

    Each rollout stops as soon as any environment is done; the episode reward is
    read at the ``done`` steps. After every rollout ``dump_video`` is applied to
    the test environment.
    """
    test_rewards = []
    for _ in range(num_episodes):
        td_test = test_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=10_000_000,
        )
        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
        test_rewards.append(reward.cpu())
        test_env.apply(dump_video)
        del td_test
    return torch.cat(test_rewards, 0).mean()


class TestGAIL(LossModuleTestBase):
    """Unit tests for ``GAILLoss``."""

    seed = 0

    def _create_mock_discriminator(
        self, batch=2, obs_dim=3, action_dim=4, device="cpu"
    ):
        """Build a small two-stage discriminator ending with a sigmoid."""
        # Discriminator
        body = TensorDictModule(
            MLP(
                in_features=obs_dim + action_dim,
                out_features=32,
                depth=1,
                num_cells=32,
                activation_class=torch.nn.ReLU,
                activate_last_layer=True,
            ),
            in_keys=["observation", "action"],
            out_keys="hidden",
        )
        head = TensorDictModule(
            MLP(
                in_features=32,
                out_features=1,
                depth=0,
                num_cells=32,
                activation_class=torch.nn.Sigmoid,
                activate_last_layer=True,
            ),
            in_keys="hidden",
            out_keys="d_logits",
        )
        discriminator = TensorDictSequential(body, head)

        return discriminator.to(device)

    def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
        """Create a ``(batch,)`` tensordict; collector entries alias the expert ones."""
        # create a tensordict
        obs = torch.randn(batch, obs_dim, device=device)
        action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
        td = TensorDict(
            batch_size=(batch,),
            source={
                "observation": obs,
                "action": action,
                "collector_action": action,
                "collector_observation": obs,
            },
            device=device,
        )
        return td

    def _create_seq_mock_data_gail(
        self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu"
    ):
        """Create a ``(batch, T)`` tensordict; collector entries alias the expert ones."""
        # create a tensordict
        obs = torch.randn(batch, T, obs_dim, device=device)
        action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)

        td = TensorDict(
            batch_size=(batch, T),
            source={
                "observation": obs,
                "action": action,
                "collector_action": action,
                "collector_observation": obs,
            },
            device=device,
        )
        return td

    def _check_loss_gradients(self, make_data, device, use_grad_penalty, gp_lambda):
        """Shared body of ``test_gail`` and ``test_seq_gail``.

        Checks that only discriminator parameters receive gradients, that no
        parameter or buffer is registered twice, and that every parameter ends
        up with a non-zero gradient.
        """
        torch.manual_seed(self.seed)
        td = make_data(device=device)

        discriminator = self._create_mock_discriminator(device=device)

        loss_fn = GAILLoss(
            discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
        )
        loss = loss_fn(td)
        loss_transformer = loss["loss"]
        loss_transformer.backward(retain_graph=True)
        named_parameters = loss_fn.named_parameters()

        for name, p in named_parameters:
            if p.grad is not None and p.grad.norm() > 0.0:
                assert "discriminator" in name
            if p.grad is None:
                assert "discriminator" not in name
        loss_fn.zero_grad()

        sum([loss_transformer]).backward()
        named_parameters = list(loss_fn.named_parameters())
        named_buffers = list(loss_fn.named_buffers())

        assert len({p for n, p in named_parameters}) == len(list(named_parameters))
        assert len({p for n, p in named_buffers}) == len(list(named_buffers))

        for name, p in named_parameters:
            assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"

    def test_gail_tensordict_keys(self):
        discriminator = self._create_mock_discriminator()
        loss_fn = GAILLoss(discriminator)

        default_keys = {
            "expert_action": "action",
            "expert_observation": "observation",
            "collector_action": "collector_action",
            "collector_observation": "collector_observation",
            "discriminator_pred": "d_logits",
        }

        self.tensordict_keys_test(
            loss_fn,
            default_keys=default_keys,
        )

    @pytest.mark.parametrize("device", get_default_devices())
    @pytest.mark.parametrize("use_grad_penalty", [True, False])
    @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
    def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda):
        torch.manual_seed(self.seed)
        discriminator = self._create_mock_discriminator(device=device)
        loss_fn = GAILLoss(
            discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
        )

        tensordict = self._create_mock_data_gail(device=device)

        in_keys = self._flatten_in_keys(loss_fn.in_keys)
        kwargs = dict(tensordict.flatten_keys("_").select(*in_keys))

        loss_val_td = loss_fn(tensordict)
        # With the gradient penalty enabled the loss exposes two out keys, so
        # the keyword-argument call returns a tuple.
        if use_grad_penalty:
            loss_val, _ = loss_fn(**kwargs)
        else:
            loss_val = loss_fn(**kwargs)

        torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
        # test select
        loss_fn.select_out_keys("loss")
        # NOTE(review): this compares version strings lexicographically.
        if torch.__version__ >= "2.0.0":
            loss_discriminator = loss_fn(**kwargs)
        else:
            with pytest.raises(
                RuntimeError,
                match="You are likely using tensordict.nn.dispatch with keyword arguments",
            ):
                loss_discriminator = loss_fn(**kwargs)
            return
        assert loss_discriminator == loss_val_td["loss"]

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("use_grad_penalty", [True, False])
    @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
    def test_gail(self, device, use_grad_penalty, gp_lambda):
        self._check_loss_gradients(
            self._create_mock_data_gail, device, use_grad_penalty, gp_lambda
        )

    @pytest.mark.parametrize("device", get_available_devices())
    def test_gail_state_dict(self, device):
        torch.manual_seed(self.seed)

        discriminator = self._create_mock_discriminator(device=device)

        loss_fn = GAILLoss(discriminator)
        sd = loss_fn.state_dict()
        loss_fn2 = GAILLoss(discriminator)
        loss_fn2.load_state_dict(sd)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("use_grad_penalty", [True, False])
    @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
    def test_seq_gail(self, device, use_grad_penalty, gp_lambda):
        self._check_loss_gradients(
            self._create_seq_mock_data_gail, device, use_grad_penalty, gp_lambda
        )

    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
    @pytest.mark.parametrize("use_grad_penalty", [True, False])
    @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
    def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda):
        torch.manual_seed(self.seed)
        device = (
            torch.device("cpu")
            if torch.cuda.device_count() == 0
            else torch.device("cuda")
        )
        td = self._create_mock_data_gail(device=device)
        discriminator = self._create_mock_discriminator(device=device)
        # Forward the parametrized penalty settings so that the reduction is
        # also checked with the gradient penalty enabled.
        loss_fn = GAILLoss(
            discriminator,
            use_grad_penalty=use_grad_penalty,
            gp_lambda=gp_lambda,
            reduction=reduction,
        )
        loss = loss_fn(td)
        if reduction == "none":
            assert loss["loss"].shape == (td["observation"].shape[0], 1)
        else:
            assert loss["loss"].shape == torch.Size([])
DreamerModelLoss, DreamerValueLoss +from .gail import GAILLoss from .iql import DiscreteIQLLoss, IQLLoss from .multiagent import QMixerLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py new file mode 100644 index 00000000000..3c0050fca84 --- /dev/null +++ b/torchrl/objectives/gail.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +import torch.autograd as autograd +from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey + +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import _reduce + + +class GAILLoss(LossModule): + r"""TorchRL implementation of the Generative Adversarial Imitation Learning (GAIL) loss. + + Presented in `"Generative Adversarial Imitation Learning" ` + + Args: + discriminator_network (TensorDictModule): stochastic actor + + Keyword Args: + use_grad_penalty (bool, optional): Whether to use gradient penalty. Default: ``False``. + gp_lambda (float, optional): Gradient penalty lambda. Default: ``10``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. 
+ + Attributes: + expert_action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + expert_observation (NestedKey): The tensordict key where the observation is expected. + Defaults to ``"observation"``. + collector_action (NestedKey): The tensordict key where the collector action is expected. + Defaults to ``"collector_action"``. + collector_observation (NestedKey): The tensordict key where the collector observation is expected. + Defaults to ``"collector_observation"``. + discriminator_pred (NestedKey): The tensordict key where the discriminator prediction is expected. + """ + + expert_action: NestedKey = "action" + expert_observation: NestedKey = "observation" + collector_action: NestedKey = "collector_action" + collector_observation: NestedKey = "collector_observation" + discriminator_pred: NestedKey = "d_logits" + + default_keys = _AcceptedKeys() + + discriminator_network: TensorDictModule + discriminator_network_params: TensorDictParams + target_discriminator_network: TensorDictModule + target_discriminator_network_params: TensorDictParams + + out_keys = [ + "loss", + "gp_loss", + ] + + def __init__( + self, + discriminator_network: TensorDictModule, + *, + use_grad_penalty: bool = False, + gp_lambda: float = 10, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + + # Discriminator Network + self.convert_to_functional( + discriminator_network, + "discriminator_network", + create_target_params=False, + ) + self.loss_function = torch.nn.BCELoss(reduction="none") + self.use_grad_penalty = use_grad_penalty + self.gp_lambda = gp_lambda + + self.reduction = reduction + + def _set_in_keys(self): + keys = self.discriminator_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.expert_observation) + keys.add(self.tensor_keys.expert_action) + keys.add(self.tensor_keys.collector_observation) + 
keys.add(self.tensor_keys.collector_action) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss"] + if self.use_grad_penalty: + keys.append("gp_loss") + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + """The forward method. + + Computes the discriminator loss and gradient penalty if `use_grad_penalty` is set to True. If `use_grad_penalty` is set to True, the detached gradient penalty loss is also returned for logging purposes. + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. 
+ """ + device = self.discriminator_network.device + tensordict = tensordict.clone(False) + shape = tensordict.shape + if len(shape) > 1: + batch_size, seq_len = shape + else: + batch_size = shape[0] + collector_obs = tensordict.get(self.tensor_keys.collector_observation) + collector_act = tensordict.get(self.tensor_keys.collector_action) + + expert_obs = tensordict.get(self.tensor_keys.expert_observation) + expert_act = tensordict.get(self.tensor_keys.expert_action) + + combined_obs_inputs = torch.cat([expert_obs, collector_obs], dim=0) + combined_act_inputs = torch.cat([expert_act, collector_act], dim=0) + + combined_inputs = TensorDict( + { + self.tensor_keys.expert_observation: combined_obs_inputs, + self.tensor_keys.expert_action: combined_act_inputs, + }, + batch_size=[2 * batch_size], + device=device, + ) + + # create + if len(shape) > 1: + fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to( + device + ) + real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to( + device + ) + else: + fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device) + real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device) + + with self.discriminator_network_params.to_module(self.discriminator_network): + d_logits = self.discriminator_network(combined_inputs).get( + self.tensor_keys.discriminator_pred + ) + + expert_preds, collection_preds = torch.split( + d_logits, [batch_size, batch_size], dim=0 + ) + + expert_loss = self.loss_function(expert_preds, real_labels) + collection_loss = self.loss_function(collection_preds, fake_labels) + + loss = expert_loss + collection_loss + out = {} + if self.use_grad_penalty: + obs = tensordict.get(self.tensor_keys.collector_observation) + acts = tensordict.get(self.tensor_keys.collector_action) + obs_e = tensordict.get(self.tensor_keys.expert_observation) + acts_e = tensordict.get(self.tensor_keys.expert_action) + + obss_noise = ( + torch.distributions.Uniform(0.0, 
1.0).sample(obs_e.shape).to(device) + ) + acts_noise = ( + torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device) + ) + obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e + acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e + obss_mixture.requires_grad_(True) + acts_mixture.requires_grad_(True) + + pg_input_td = TensorDict( + { + self.tensor_keys.expert_observation: obss_mixture, + self.tensor_keys.expert_action: acts_mixture, + }, + [], + device=device, + ) + + with self.discriminator_network_params.to_module( + self.discriminator_network + ): + d_logits_mixture = self.discriminator_network(pg_input_td).get( + self.tensor_keys.discriminator_pred + ) + + gradients = torch.cat( + autograd.grad( + outputs=d_logits_mixture, + inputs=(obss_mixture, acts_mixture), + grad_outputs=torch.ones(d_logits_mixture.size(), device=device), + create_graph=True, + retain_graph=True, + only_inputs=True, + ), + dim=-1, + ) + + gp_loss = self.gp_lambda * torch.mean( + (torch.linalg.norm(gradients, dim=-1) - 1) ** 2 + ) + + loss += gp_loss + out["gp_loss"] = gp_loss.detach() + loss = _reduce(loss, reduction=self.reduction) + out["loss"] = loss + td_out = TensorDict(out, []) + return td_out