From f149811da93991bfb798fc3cfa6641b38e321cfd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 17:23:37 -0800 Subject: [PATCH] [Feature] DQN compatibility with compile ghstack-source-id: 113dc8c4a5562d217ed867ace1942b2f6b8a39f9 Pull Request resolved: https://github.com/pytorch/rl/pull/2571 --- sota-implementations/dqn/config_atari.yaml | 5 + sota-implementations/dqn/config_cartpole.yaml | 5 + sota-implementations/dqn/dqn_atari.py | 123 ++++++++++-------- sota-implementations/dqn/dqn_cartpole.py | 110 ++++++++++------ sota-implementations/dqn/utils_atari.py | 15 ++- sota-implementations/dqn/utils_cartpole.py | 11 +- torchrl/data/tensor_specs.py | 19 ++- .../modules/tensordict_module/exploration.py | 20 ++- 8 files changed, 193 insertions(+), 115 deletions(-) diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index 50e374cef14..021e7fd6132 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -39,3 +39,8 @@ loss: gamma: 0.99 hard_update_freq: 10_000 num_updates: 1 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml index 9a69762d6bd..58be7fb3bb5 100644 --- a/sota-implementations/dqn/config_cartpole.yaml +++ b/sota-implementations/dqn/config_cartpole.yaml @@ -38,3 +38,8 @@ loss: gamma: 0.99 hard_update_freq: 50 num_updates: 1 + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 4f37502ab76..255b6b2ee65 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -10,14 +10,14 @@ from __future__ import annotations import tempfile -import time +import warnings import hydra import torch.nn import torch.optim import tqdm -from tensordict.nn import TensorDictSequential -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule, TensorDictSequential +from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -48,28 +48,17 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Make the components - model = make_dqn_model(cfg.env.env_name, frame_skip) + model = make_dqn_model(cfg.env.env_name, frame_skip, device=device) greedy_module = EGreedyModule( annealing_num_steps=cfg.collector.annealing_frames, eps_init=cfg.collector.eps_start, eps_end=cfg.collector.eps_end, spec=model.spec, + device=device, ) model_explore = TensorDictSequential( model, greedy_module, - ).to(device) - - # Create the collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, frame_skip, device), - policy=model_explore, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - init_random_frames=init_random_frames, ) # Create the replay buffer @@ -129,25 +118,70 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + def update(sampled_tensordict): + loss_td = loss_module(sampled_tensordict) + q_loss = loss_td["loss"] + optimizer.zero_grad() + q_loss.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad + ) + optimizer.step() + target_net_updater.step() + return q_loss.detach() + + compile_mode = None + if 
cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Create the collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, frame_skip, device), + policy=model_explore, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + init_random_frames=init_random_frames, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + # Main loop collected_frames = 0 - start_time = time.time() - sampling_start = time.time() num_updates = cfg.loss.num_updates max_grad = cfg.optim.max_grad_norm num_test_episodes = cfg.logger.num_test_episodes q_losses = torch.zeros(num_updates, device=device) pbar = tqdm.tqdm(total=total_frames) - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() * frame_skip collected_frames += current_frames greedy_module.step(current_frames) - replay_buffer.extend(data) + with timeit("rb - extend"): + replay_buffer.extend(data) # Get and log training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "done"]] @@ -169,57 +203,47 @@ def main(cfg: "DictConfig"): # noqa: F821 continue # optimization steps - training_start = time.time() for j in range(num_updates): - - sampled_tensordict = replay_buffer.sample() - sampled_tensordict = sampled_tensordict.to(device) - - loss_td = loss_module(sampled_tensordict) - q_loss = loss_td["loss"] - optimizer.zero_grad() - q_loss.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=max_grad - ) - optimizer.step() - target_net_updater.step() - q_losses[j].copy_(q_loss.detach()) - - training_time = time.time() - training_start + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + q_loss = update(sampled_tensordict) + q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( { - "train/q_values": (data["action_value"] * data["action"]).sum().item() - / frames_per_batch, - "train/q_loss": q_losses.mean().item(), + "train/q_values": data["chosen_action_value"].sum() / frames_per_batch, + "train/q_loss": q_losses.mean(), "train/epsilon": greedy_module.eps, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: 
model.eval() - eval_start = time.time() test_rewards = eval_model( model, test_env, num_episodes=num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "eval/reward": test_rewards, - "eval/eval_time": eval_time, } ) model.train() + if i % 200 == 0: + timeit.print() + log_info.update(timeit.todict(prefix="time")) + timeit.erase() + # Log all the information if logger: for key, value in log_info.items(): @@ -227,16 +251,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # update weights of the inference policy collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") - if __name__ == "__main__": main() diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index b97d8c904fd..89a1e04d586 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -2,17 +2,18 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + from __future__ import annotations -import time +import warnings import hydra import torch.nn import torch.optim import tqdm -from tensordict.nn import TensorDictSequential -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule, TensorDictSequential +from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.envs import ExplorationType, set_exploration_type @@ -35,29 +36,18 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(device) # Make the components - model = make_dqn_model(cfg.env.env_name) + model = make_dqn_model(cfg.env.env_name, device=device) greedy_module = EGreedyModule( annealing_num_steps=cfg.collector.annealing_frames, eps_init=cfg.collector.eps_start, eps_end=cfg.collector.eps_end, spec=model.spec, + device=device, ) model_explore = TensorDictSequential( model, greedy_module, - ).to(device) - - # Create the collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, "cpu"), - policy=model_explore, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device="cpu", - storing_device="cpu", - max_frames_per_traj=-1, - init_random_frames=cfg.collector.init_random_frames, ) # Create the replay buffer @@ -111,9 +101,49 @@ def main(cfg: "DictConfig"): # noqa: F821 ), ) + def update(sampled_tensordict): + loss_td = loss_module(sampled_tensordict) + q_loss = loss_td["loss"] + optimizer.zero_grad() + q_loss.backward() + optimizer.step() + target_net_updater.step() + return q_loss.detach() + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Create the collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, "cpu"), + policy=model_explore, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device="cpu", + storing_device="cpu", + max_frames_per_traj=-1, + init_random_frames=cfg.collector.init_random_frames, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + # Main loop collected_frames = 0 - start_time = time.time() num_updates = cfg.loss.num_updates batch_size = cfg.buffer.batch_size test_interval = cfg.logger.test_interval @@ -121,17 +151,20 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - sampling_start = time.time() q_losses = torch.zeros(num_updates, device=device) - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() - replay_buffer.extend(data) + + with timeit("rb - extend"): + replay_buffer.extend(data) collected_frames += current_frames greedy_module.step(current_frames) @@ -156,18 +189,13 @@ def main(cfg: "DictConfig"): # noqa: F821 continue # optimization steps - training_start = time.time() for j in range(num_updates): - sampled_tensordict = replay_buffer.sample(batch_size) - sampled_tensordict = sampled_tensordict.to(device) - loss_td = loss_module(sampled_tensordict) - q_loss = loss_td["loss"] - optimizer.zero_grad() - q_loss.backward() - optimizer.step() - target_net_updater.step() - q_losses[j].copy_(q_loss.detach()) - training_time = time.time() - training_start + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample(batch_size) + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + q_loss = update(sampled_tensordict) + q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( @@ -176,29 +204,31 @@ def main(cfg: "DictConfig"): # noqa: F821 / frames_per_batch, "train/q_loss": q_losses.mean().item(), "train/epsilon": greedy_module.eps, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() - eval_start = time.time() test_rewards = eval_model(model, test_env, num_test_episodes) - eval_time = time.time() - eval_start model.train() log_info.update( { "eval/reward": test_rewards, - "eval/eval_time": eval_time, } ) + if i % 200 == 0: + timeit.print() + log_info.update(timeit.todict(prefix="time")) + timeit.erase() + # Log all the information if logger: for key, value in log_info.items(): @@ -206,14 +236,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # update weights of the 
inference policy collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 1e5440a54b6..0956dfeb2ac 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -39,6 +39,7 @@ def make_env(env_name, frame_skip, device, is_test=False): from_pixels=True, pixels_only=False, device=device, + categorical_action_encoding=True, ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) @@ -61,7 +62,7 @@ def make_env(env_name, frame_skip, device, is_test=False): # -------------------------------------------------------------------- -def make_dqn_modules_pixels(proof_environment): +def make_dqn_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -75,25 +76,27 @@ def make_dqn_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - cnn_output = cnn(torch.ones(input_shape)) + cnn_output = cnn(torch.ones(input_shape, device=device)) mlp = MLP( in_features=cnn_output.shape[-1], activation_class=torch.nn.ReLU, out_features=num_actions, num_cells=[512], + device=device, ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["pixels"], ) return qvalue_module -def make_dqn_model(env_name, frame_skip): - proof_environment = make_env(env_name, frame_skip, device="cpu") - qvalue_module = make_dqn_modules_pixels(proof_environment) +def make_dqn_model(env_name, frame_skip, device): + proof_environment = make_env(env_name, frame_skip, device=device) + qvalue_module = make_dqn_modules_pixels(proof_environment, device=device) del proof_environment return qvalue_module diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index d378f1ec76b..c49ff15f5fc 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -31,7 +31,7 @@ def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False): # -------------------------------------------------------------------- -def make_dqn_modules(proof_environment): +def make_dqn_modules(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -45,19 +45,20 @@ def make_dqn_modules(proof_environment): activation_class=torch.nn.ReLU, out_features=num_outputs, num_cells=[120, 84], + device=device, ) qvalue_module = QValueActor( module=mlp, - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["observation"], ) return qvalue_module -def make_dqn_model(env_name): - proof_environment = make_env(env_name, device="cpu") - qvalue_module = make_dqn_modules(proof_environment) +def make_dqn_model(env_name, device): + proof_environment = make_env(env_name, device=device) + qvalue_module = make_dqn_modules(proof_environment, device=device) del proof_environment return qvalue_module diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c03fb40f1ac..1898e679717 100644 --- a/torchrl/data/tensor_specs.py +++ 
b/torchrl/data/tensor_specs.py @@ -455,13 +455,18 @@ def __eq__(self, other): ) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class CategoricalBox(Box): """A box of discrete, categorical values.""" n: int register = invertible_dict() + def __post_init__(self): + # n could be a numpy array or a tensor, making compile go a bit crazy + # We want to make sure we're working with a regular integer + self.__dict__["n"] = int(self.n) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: return deepcopy(self) @@ -502,7 +507,7 @@ def from_nvec(nvec: torch.Tensor): return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class BinaryBox(Box): """A box of n binary values.""" @@ -1695,7 +1700,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - @implement_for("torch", None, "2.1") + @implement_for("torch", None, "2.1", compilable=True) def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] @@ -1718,7 +1723,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: # out.scatter_(-1, m, 1) return out - @implement_for("torch", "2.1") + @implement_for("torch", "2.1", compilable=True) def rand(self, shape: torch.Size = None) -> torch.Tensor: # noqa: F811 if shape is None: shape = self.shape[:-1] @@ -3314,6 +3319,10 @@ def __init__( self.update_mask(mask) self._provisional_n = None + @torch.compiler.assume_constant_result + def _undefined_n(self): + return self.space.n == -1 + def enumerate(self) -> torch.Tensor: dtype = self.dtype if dtype is torch.bool: @@ -3379,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self.space.n < 0: + if self._undefined_n(): if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). " diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index a1879519271..6e8296a677a 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -55,6 +55,7 @@ class EGreedyModule(TensorDictModuleBase): Default is ``"action"``. action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. Default is ``None`` (corresponding to no mask). + device (torch.device, optional): the device of the exploration module. .. 
note:: It is crucial to incorporate a call to :meth:`~.step` in the training loop @@ -97,6 +98,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", action_mask_key: Optional[NestedKey] = None, + device: torch.device | None = None, ): if not isinstance(eps_init, float): warnings.warn("eps_init should be a float.") @@ -112,14 +114,18 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.as_tensor(eps_init)) - self.register_buffer("eps_end", torch.as_tensor(eps_end)) + self.register_buffer("eps_init", torch.as_tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.as_tensor(eps_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32)) + self.register_buffer( + "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device) + ) if spec is not None: if not isinstance(spec, Composite) and len(self.out_keys) >= 1: spec = Composite({action_key: spec}, shape=spec.shape[:-1]) + if device is not None: + spec = spec.to(device) self._spec = spec @property @@ -147,7 +153,8 @@ def step(self, frames: int = 1) -> None: ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + expl = exploration_type() + if expl in (ExplorationType.RANDOM, None): if isinstance(self.action_key, tuple) and len(self.action_key) > 1: action_tensordict = tensordict.get(self.action_key[:-1]) action_key = self.action_key[-1] @@ -183,7 +190,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"Action mask key {self.action_mask_key} not found in {tensordict}." ) spec.update_mask(action_mask) - out = torch.where(cond, spec.rand().to(out.device), out) + r = spec.rand() + if r.device != out.device: + r = r.to(out.device) + out = torch.where(cond, r, out) else: raise RuntimeError("spec must be provided to the exploration wrapper.") action_tensordict.set(action_key, out)
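
For reference, below is a minimal, self-contained sketch of the config-driven wrapping pattern this patch adds to both DQN scripts. It is not part of the patch: the toy network, optimizer, random batch and the plain dict standing in for the Hydra `compile` node are placeholders introduced here for illustration, and the exact nesting in the scripts may differ slightly; only the compile/cudagraphs wrapping logic mirrors the diff above.

# Sketch (assumptions noted in comments): how the new `compile` config block is
# consumed to optionally wrap the per-batch update function.
import warnings

import torch
import torch.nn as nn

# Stand-in for the cfg.compile node added to config_atari.yaml / config_cartpole.yaml.
compile_cfg = {"compile": True, "compile_mode": None, "cudagraphs": False}

device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Linear(4, 2, device=device)          # placeholder for the Q-network
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def update(obs, target):
    # One optimization step; plays the role of the `update` closure the scripts compile.
    loss = nn.functional.mse_loss(net(obs), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


compile_mode = None
if compile_cfg["compile"]:
    compile_mode = compile_cfg["compile_mode"]
    if compile_mode in ("", None):
        # Same default selection as the patch: "default" under cudagraphs,
        # "reduce-overhead" otherwise.
        compile_mode = "default" if compile_cfg["cudagraphs"] else "reduce-overhead"
    update = torch.compile(update, mode=compile_mode)
if compile_cfg["cudagraphs"]:
    # Requires a CUDA device; mirrors the warning emitted by the patch.
    warnings.warn(
        "CudaGraphModule is experimental and may lead to silently wrong results. "
        "Use with caution.",
        category=UserWarning,
    )
    from tensordict.nn import CudaGraphModule

    update = CudaGraphModule(update, warmup=50)

# Dummy training step to exercise the wrapped update.
obs = torch.randn(32, 4, device=device)
target = torch.randn(32, 2, device=device)
print(update(obs, target))

With the defaults shipped in the YAML files (compile: False, cudagraphs: False) the scripts behave as before. Enabling the feature from the command line would use the standard Hydra override syntax the scripts already rely on, e.g. `python dqn_atari.py compile.compile=True compile.cudagraphs=True`. Picking mode "default" when cudagraphs is requested presumably avoids stacking torch.compile's own "reduce-overhead" CUDA-graph capture on top of CudaGraphModule.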