From f5a187d7d4fb26fd4d57bfcbca6f2a9dfa3c91ad Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 14:01:47 -0800 Subject: [PATCH] [Feature] PPO compatibility with compile ghstack-source-id: 0ed29f352fcd85f0dc0683d90e95bdbecf6c14f9 Pull Request resolved: https://github.com/pytorch/rl/pull/2652 --- sota-implementations/dqn/dqn_atari.py | 2 + sota-implementations/dqn/dqn_cartpole.py | 2 + sota-implementations/ppo/config_atari.yaml | 6 + sota-implementations/ppo/config_mujoco.yaml | 6 + sota-implementations/ppo/ppo_atari.py | 187 +++++++++++------- sota-implementations/ppo/ppo_mujoco.py | 202 ++++++++++++-------- sota-implementations/ppo/utils_atari.py | 29 +-- sota-implementations/ppo/utils_mujoco.py | 18 +- torchrl/_utils.py | 6 +- torchrl/collectors/collectors.py | 6 +- 10 files changed, 288 insertions(+), 176 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 255b6b2ee65..d43ac25c822 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -28,6 +28,8 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_dqn_model, make_env +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 89a1e04d586..873cf278d4b 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -23,6 +23,8 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils_cartpole import eval_model, make_dqn_model, make_env +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 diff --git a/sota-implementations/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml index 31e6f13a58c..f7a340e3512 100644 --- a/sota-implementations/ppo/config_atari.yaml +++ b/sota-implementations/ppo/config_atari.yaml @@ -25,6 +25,7 @@ optim: weight_decay: 0.0 max_grad_norm: 0.5 anneal_lr: True + device: # loss loss: @@ -37,3 +38,8 @@ loss: critic_coef: 1.0 entropy_coef: 0.01 loss_critic_type: l2 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/ppo/config_mujoco.yaml b/sota-implementations/ppo/config_mujoco.yaml index 2dd3c6cc229..822aea89616 100644 --- a/sota-implementations/ppo/config_mujoco.yaml +++ b/sota-implementations/ppo/config_mujoco.yaml @@ -22,6 +22,7 @@ optim: lr: 3e-4 weight_decay: 0.0 anneal_lr: True + device: # loss loss: @@ -34,3 +35,8 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 7878a0286e3..cc42ef38f9d 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -9,30 +9,42 @@ """ from __future__ import annotations +import warnings + import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder + +from torchrl._utils import compile_with_warmup @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time - import torch.optim import tqdm from tensordict import TensorDict + from tensordict.nn import CudaGraphModule + + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value.advantages import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + torch.set_float32_matmul_precision("high") + + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -41,9 +53,17 @@ def main(cfg: "DictConfig"): # noqa: F821 mini_batch_size = cfg.loss.mini_batch_size // frame_skip test_interval = cfg.logger.test_interval // frame_skip + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create models (check utils_atari.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) + actor, critic = make_ppo_models(cfg.env.env_name, device=device) # Create collector collector = SyncDataCollector( @@ -51,17 +71,22 @@ def main(cfg: "DictConfig"): # noqa: F821 policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, - device="cpu", - storing_device="cpu", + device=device, + storing_device=device, max_frames_per_traj=-1, + compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False, + cudagraph_policy=cfg.compile.cudagraphs, ) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage( + frames_per_batch, compilable=cfg.compile.compile, device=device + ), sampler=sampler, batch_size=mini_batch_size, + compilable=cfg.compile.compile, ) # Create loss and adv modules @@ -70,6 +95,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss( actor_network=actor, @@ -121,15 +148,52 @@ def main(cfg: "DictConfig"): # noqa: F821 # Main loop collected_frames = 0 - num_network_updates = 0 - start_time = time.time() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = ( (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches ) - sampling_start = time.time() + def update(batch, num_network_updates): + optim.zero_grad(set_to_none=True) + + # Linearly decrease the learning rate and clip epsilon + alpha = torch.ones((), device=device) + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates = num_network_updates + 1 + # Get a data batch + batch = batch.to(device, non_blocking=True) + + # Forward pass PPO loss + loss = loss_module(batch) + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=cfg_optim_max_grad_norm + ) + + # Update the networks + optim.step() + return loss.detach().set("alpha", alpha), num_network_updates + + if cfg.compile.compile: + update = compile_with_warmup(update, mode=compile_mode, warmup=1) + adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1) + + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) # extract cfg variables cfg_loss_ppo_epochs = cfg.loss.ppo_epochs @@ -142,13 +206,16 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.loss.clip_epsilon = cfg_loss_clip_epsilon losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) - for i, data in enumerate(collector): + collector_iter = iter(collector) + + for i in range(len(collector)): + with timeit("collecting"): + data = next(collector_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) + pbar.update(frames_in_batch) # Get training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] @@ -162,96 +229,70 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - training_start = time.time() - for j in range(cfg_loss_ppo_epochs): - - # Compute GAE - with torch.no_grad(): - data = adv_module(data.to(device, non_blocking=True)) - data_reshape = data.reshape(-1) - # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg_optim_anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - if cfg_loss_anneal_clip_eps: - loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) - num_network_updates += 1 - # Get a data batch - batch = batch.to(device, non_blocking=True) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg_optim_max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() + with timeit("training"): + for j in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(), timeit("adv"): + torch.compiler.cudagraph_mark_step_begin() + data = adv_module(data) + if compile_mode: + data = data.clone() + with timeit("rb - extend"): + # Update the data buffer + data_reshape = data.reshape(-1) + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + torch.compiler.cudagraph_mark_step_begin() + loss, num_network_updates = update( + batch, num_network_updates=num_network_updates + ) + loss = loss.clone() + num_network_updates = num_network_updates.clone() + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ) # Get training losses and times - training_time = time.time() - training_start losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses_mean.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg_optim_lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, - "train/clip_epsilon": alpha * cfg_loss_clip_epsilon, + "train/lr": loss["alpha"] * cfg_optim_lr, + "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon, } ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: actor.eval() - eval_start = time.time() test_rewards = eval_model( actor, test_env, num_episodes=cfg_logger_num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "eval/reward": test_rewards.mean(), - "eval/time": eval_time, } ) actor.train() - if logger: + log_info.update(timeit.todict(prefix="time")) for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") - if __name__ == "__main__": main() diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index c1d6fe52585..a0cf2726aca 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -9,30 +9,43 @@ """ from __future__ import annotations +import warnings + import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder + +from torchrl._utils import compile_with_warmup @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time - import torch.optim import tqdm from tensordict import TensorDict + from tensordict.nn import CudaGraphModule + + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import ClipPPOLoss + from torchrl.objectives import ClipPPOLoss, group_optimizers from torchrl.objectives.value.advantages import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_mujoco import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + torch.set_float32_matmul_precision("high") + + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( (cfg.collector.total_frames // cfg.collector.frames_per_batch) @@ -40,9 +53,17 @@ def main(cfg: "DictConfig"): # noqa: F821 * num_mini_batches ) + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) + actor, critic = make_ppo_models(cfg.env.env_name, device=device) # Create collector collector = SyncDataCollector( @@ -53,14 +74,21 @@ def main(cfg: "DictConfig"): # noqa: F821 device=device, storing_device=device, max_frames_per_traj=-1, + compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False, + cudagraph_policy=cfg.compile.cudagraphs, ) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + storage=LazyTensorStorage( + cfg.collector.frames_per_batch, + compilable=cfg.compile.compile, + device=device, + ), sampler=sampler, batch_size=cfg.loss.mini_batch_size, + compilable=cfg.compile.compile, ) # Create loss and adv modules @@ -69,6 +97,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + device=device, + vectorized=not cfg.compile.compile, ) loss_module = ClipPPOLoss( @@ -82,8 +112,14 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5) + actor_optim = torch.optim.Adam( + actor.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5 + ) + critic_optim = torch.optim.Adam( + critic.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5 + ) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim # Create logger logger = None @@ -111,31 +147,68 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + def update(batch, num_network_updates): + optim.zero_grad(set_to_none=True) + # Linearly decrease the learning rate and clip epsilon + alpha = torch.ones((), device=device) + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates = num_network_updates + 1 + + # Forward pass PPO loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + total_loss = critic_loss + actor_loss + + # Backward pass + total_loss.backward() + + # Update the networks + optim.step() + return loss.detach().set("alpha", alpha), num_network_updates + + if cfg.compile.compile: + update = compile_with_warmup(update, mode=compile_mode, warmup=1) + adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1) + + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) + # Main loop collected_frames = 0 - num_network_updates = 0 - start_time = time.time() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) pbar = tqdm.tqdm(total=cfg.collector.total_frames) - sampling_start = time.time() - # extract cfg variables cfg_loss_ppo_epochs = cfg.loss.ppo_epochs cfg_optim_anneal_lr = cfg.optim.anneal_lr - cfg_optim_lr = cfg.optim.lr + cfg_optim_lr = torch.tensor(cfg.optim.lr, device=device) cfg_loss_anneal_clip_eps = cfg.loss.anneal_clip_epsilon cfg_loss_clip_epsilon = cfg.loss.clip_epsilon cfg_logger_test_interval = cfg.logger.test_interval cfg_logger_num_test_episodes = cfg.logger.num_test_episodes losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches]) - for i, data in enumerate(collector): + collector_iter = iter(collector) + + for i in range(len(collector)): + with timeit("collecting"): + data = next(collector_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch - pbar.update(data.numel()) + pbar.update(frames_in_batch) # Get training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "done"]] @@ -149,100 +222,73 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - training_start = time.time() - for j in range(cfg_loss_ppo_epochs): - - # Compute GAE - with torch.no_grad(): - data = adv_module(data) - data_reshape = data.reshape(-1) - - # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg_optim_anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - for group in critic_optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - if cfg_loss_anneal_clip_eps: - loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) - num_network_updates += 1 - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + with timeit("training"): + for j in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(), timeit("adv"): + torch.compiler.cudagraph_mark_step_begin() + data = adv_module(data) + if compile_mode: + data = data.clone() + + with timeit("rb - extend"): + # Update the data buffer + data_reshape = data.reshape(-1) + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + torch.compiler.cudagraph_mark_step_begin() + loss, num_network_updates = update( + batch, num_network_updates=num_network_updates + ) + loss = loss.clone() + num_network_updates = num_network_updates.clone() + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ) # Get training losses and times - training_time = time.time() - training_start losses_mean = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses_mean.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg_optim_lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, - "train/clip_epsilon": alpha * cfg_loss_clip_epsilon + "train/lr": loss["alpha"] * cfg_optim_lr, + "train/clip_epsilon": loss["alpha"] * cfg_loss_clip_epsilon if cfg_loss_anneal_clip_eps else cfg_loss_clip_epsilon, } ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: actor.eval() - eval_start = time.time() test_rewards = eval_model( actor, test_env, num_episodes=cfg_logger_num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "eval/reward": test_rewards.mean(), - "eval/time": eval_time, } ) actor.train() if logger: + log_info.update(timeit.todict(prefix="time")) for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 040259377ad..fa9d4bb053e 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -31,7 +31,6 @@ ActorValueOperator, ConvNet, MLP, - OneHotCategorical, ProbabilisticActor, TanhNormal, ValueOperator, @@ -51,6 +50,7 @@ def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False from_pixels=True, pixels_only=False, device="cpu", + categorical_action_encoding=True, ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) @@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -94,14 +94,14 @@ def make_ppo_modules_pixels(proof_environment): # Define distribution class and kwargs if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox): num_outputs = proof_environment.action_spec_unbatched.space.n - distribution_class = OneHotCategorical + distribution_class = torch.distributions.Categorical distribution_kwargs = {} else: # is ContinuousBox num_outputs = proof_environment.action_spec_unbatched.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec_unbatched.space.low, - "high": proof_environment.action_spec_unbatched.space.high, + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), } # Define input keys @@ -113,14 +113,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.full_action_spec_unbatched, + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +174,12 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): - proof_environment = make_parallel_env(env_name, 1, device="cpu") + proof_environment = make_parallel_env(env_name, 1, device=device) common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, + device=device, ) # Wrap modules in a single ActorCritic operator @@ -185,8 +190,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(10) + actor_critic(td) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index f2e08ffb129..1f224b81528 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec_unbatched.space.low, - "high": proof_environment.action_spec_unbatched.space.high, + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, } @@ -63,6 +63,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -76,7 +77,7 @@ def make_ppo_models_state(proof_environment): policy_mlp, AddStateIndependentNormalScale( proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 - ), + ).to(device), ) # Add probabilistic sampling of the actions @@ -87,7 +88,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.full_action_spec_unbatched, + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -100,6 +101,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -117,9 +119,9 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): - proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) +def make_ppo_models(env_name, device): + proof_environment = make_env(env_name, device=device) + actor, critic = make_ppo_models_state(proof_environment, device=device) return actor, critic diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 45f8c433725..cc1621d8723 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -854,7 +854,7 @@ def set_mode(self, type: Any | None) -> None: @wraps(torch.compile) -def compile_with_warmup(*args, warmup: int, **kwargs): +def compile_with_warmup(*args, warmup: int = 1, **kwargs): """Compile a model with warm-up. This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase, @@ -863,7 +863,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs): Args: *args: Arguments to be passed to `torch.compile`. - warmup (int): Number of calls to the model before compiling it. + warmup (int): Number of calls to the model before compiling it. Defaults to 1. **kwargs: Keyword arguments to be passed to `torch.compile`. Returns: @@ -888,7 +888,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs): if model is None: return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs) else: - count = 0 + count = -1 compiled_model = model @wraps(model) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 78c470e8b7a..f2709411e3b 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -47,6 +47,7 @@ _ProcessNoWarn, _replace_last, accept_remote_rref_udf_invocation, + compile_with_warmup, logger as torchrl_logger, prod, RL_WARNINGS, @@ -67,7 +68,6 @@ set_exploration_type, ) - try: from torch.compiler import cudagraph_mark_step_begin except ImportError: @@ -661,7 +661,9 @@ def __init__( self.policy_weights = TensorDict() if self.compiled_policy: - self.policy = torch.compile(self.policy, **self.compiled_policy_kwargs) + self.policy = compile_with_warmup( + self.policy, **self.compiled_policy_kwargs + ) if self.cudagraphed_policy: self.policy = CudaGraphModule(self.policy, **self.cudagraphed_policy_kwargs)