From cfa38bbd2d3485902f07e96246162e9fc5df57d7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 15 Nov 2024 17:45:26 +0000 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- sota-implementations/dqn/config_atari.yaml | 3 + sota-implementations/dqn/config_cartpole.yaml | 3 + sota-implementations/dqn/dqn_atari.py | 113 ++++++++++-------- sota-implementations/dqn/dqn_cartpole.py | 105 +++++++++------- 4 files changed, 135 insertions(+), 89 deletions(-) diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index 50e374cef14..be645be2fb0 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -39,3 +39,6 @@ loss: gamma: 0.99 hard_update_freq: 10_000 num_updates: 1 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml index 9a69762d6bd..cbc467b79c7 100644 --- a/sota-implementations/dqn/config_cartpole.yaml +++ b/sota-implementations/dqn/config_cartpole.yaml @@ -38,3 +38,6 @@ loss: gamma: 0.99 hard_update_freq: 50 num_updates: 1 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 5d0162080e2..c8391658688 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -8,14 +8,14 @@ Deep Q-Learning Algorithm on Atari Environments. """ import tempfile -import time +import warnings import hydra import torch.nn import torch.optim import tqdm -from tensordict.nn import TensorDictSequential -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule, TensorDictSequential +from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -58,18 +58,6 @@ def main(cfg: "DictConfig"): # noqa: F821 greedy_module, ).to(device) - # Create the collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, frame_skip, device), - policy=model_explore, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - init_random_frames=init_random_frames, - ) - # Create the replay buffer if cfg.buffer.scratch_dir is None: tempdir = tempfile.TemporaryDirectory() @@ -127,25 +115,68 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + def update(sampled_tensordict): + loss_td = loss_module(sampled_tensordict) + q_loss = loss_td["loss"] + optimizer.zero_grad() + q_loss.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad + ) + optimizer.step() + target_net_updater.step() + return q_loss.detach() + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Create the collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, frame_skip, device), + policy=model_explore, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + init_random_frames=init_random_frames, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + # Main loop collected_frames = 0 - start_time = time.time() - sampling_start = time.time() num_updates = cfg.loss.num_updates max_grad = cfg.optim.max_grad_norm num_test_episodes = cfg.logger.num_test_episodes q_losses = torch.zeros(num_updates, device=device) pbar = tqdm.tqdm(total=total_frames) - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() * frame_skip collected_frames += current_frames greedy_module.step(current_frames) - replay_buffer.extend(data) + with timeit("rb - extend"): + replay_buffer.extend(data) # Get and log training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "done"]] @@ -167,24 +198,13 @@ def main(cfg: "DictConfig"): # noqa: F821 continue # optimization steps - training_start = time.time() for j in range(num_updates): - - sampled_tensordict = replay_buffer.sample() - sampled_tensordict = sampled_tensordict.to(device) - - loss_td = loss_module(sampled_tensordict) - q_loss = loss_td["loss"] - optimizer.zero_grad() - q_loss.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=max_grad - ) - optimizer.step() - target_net_updater.step() - q_losses[j].copy_(q_loss.detach()) - - training_time = time.time() - training_start + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + q_loss = update(sampled_tensordict) + q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( @@ -193,31 +213,33 @@ def main(cfg: "DictConfig"): # noqa: F821 / frames_per_batch, "train/q_loss": q_losses.mean().item(), "train/epsilon": greedy_module.eps, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() - eval_start = time.time() test_rewards = eval_model( model, test_env, num_episodes=num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "eval/reward": test_rewards, - "eval/eval_time": eval_time, } ) model.train() + if i % 200 == 0: + timeit.print() + log_info.update(timeit.todict(prefix="time")) + timeit.erase() + # Log all the information if logger: for key, value in log_info.items(): @@ -225,16 +247,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # update weights of the inference policy collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") - if __name__ == "__main__": main() diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 8149c700958..0e216ffa734 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -2,15 +2,16 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import time + +import warnings import hydra import torch.nn import torch.optim import tqdm -from tensordict.nn import TensorDictSequential -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule, TensorDictSequential +from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.envs import ExplorationType, set_exploration_type @@ -46,18 +47,6 @@ def main(cfg: "DictConfig"): # noqa: F821 greedy_module, ).to(device) - # Create the collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, "cpu"), - policy=model_explore, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device="cpu", - storing_device="cpu", - max_frames_per_traj=-1, - init_random_frames=cfg.collector.init_random_frames, - ) - # Create the replay buffer replay_buffer = TensorDictReplayBuffer( pin_memory=False, @@ -109,9 +98,47 @@ def main(cfg: "DictConfig"): # noqa: F821 ), ) + def update(sampled_tensordict): + loss_td = loss_module(sampled_tensordict) + q_loss = loss_td["loss"] + optimizer.zero_grad() + q_loss.backward() + optimizer.step() + target_net_updater.step() + return q_loss.detach() + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Create the collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, "cpu"), + policy=model_explore, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device="cpu", + storing_device="cpu", + max_frames_per_traj=-1, + init_random_frames=cfg.collector.init_random_frames, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + # Main loop collected_frames = 0 - start_time = time.time() num_updates = cfg.loss.num_updates batch_size = cfg.buffer.batch_size test_interval = cfg.logger.test_interval @@ -119,17 +146,20 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - sampling_start = time.time() q_losses = torch.zeros(num_updates, device=device) - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start pbar.update(data.numel()) data = data.reshape(-1) current_frames = data.numel() - replay_buffer.extend(data) + + with timeit("rb - extend"): + replay_buffer.extend(data) collected_frames += current_frames greedy_module.step(current_frames) @@ -154,18 +184,13 @@ def main(cfg: "DictConfig"): # noqa: F821 continue # optimization steps - training_start = time.time() for j in range(num_updates): - sampled_tensordict = replay_buffer.sample(batch_size) - sampled_tensordict = sampled_tensordict.to(device) - loss_td = loss_module(sampled_tensordict) - q_loss = loss_td["loss"] - optimizer.zero_grad() - q_loss.backward() - optimizer.step() - target_net_updater.step() - q_losses[j].copy_(q_loss.detach()) - training_time = time.time() - training_start + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample(batch_size) + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + q_loss = update(sampled_tensordict) + q_losses[j].copy_(q_loss) # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( @@ -174,29 +199,31 @@ def main(cfg: "DictConfig"): # noqa: F821 / frames_per_batch, "train/q_loss": q_losses.mean().item(), "train/epsilon": greedy_module.eps, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() - eval_start = time.time() test_rewards = eval_model(model, test_env, num_test_episodes) - eval_time = time.time() - eval_start model.train() log_info.update( { "eval/reward": test_rewards, - "eval/eval_time": eval_time, } ) + if i % 200 == 0: + timeit.print() + log_info.update(timeit.todict(prefix="time")) + timeit.erase() + # Log all the information if logger: for key, value in log_info.items(): @@ -204,14 +231,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # update weights of the inference policy collector.update_policy_weights_() - sampling_start = time.time() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": From 3da531a93b27e302ba0c27bbeb9dff32163016a3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 15:18:36 -0800 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- sota-implementations/dqn/dqn_atari.py | 5 +++-- sota-implementations/dqn/utils_atari.py | 12 +++++++----- .../modules/tensordict_module/exploration.py | 17 +++++++++++++---- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index d8e5047e8b3..0e07462220f 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -48,17 +48,18 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Make the components - model = make_dqn_model(cfg.env.env_name, frame_skip) + model = make_dqn_model(cfg.env.env_name, frame_skip, device=device) greedy_module = EGreedyModule( annealing_num_steps=cfg.collector.annealing_frames, eps_init=cfg.collector.eps_start, eps_end=cfg.collector.eps_end, spec=model.spec, + device=device, ) model_explore = TensorDictSequential( model, greedy_module, - ).to(device) + ) # Create the replay buffer if cfg.buffer.scratch_dir is None: diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 1e5440a54b6..a135b78803d 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -61,7 +61,7 @@ def make_env(env_name, frame_skip, device, is_test=False): # -------------------------------------------------------------------- -def make_dqn_modules_pixels(proof_environment): +def make_dqn_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -75,13 +75,15 @@ def make_dqn_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - cnn_output = cnn(torch.ones(input_shape)) + cnn_output = cnn(torch.ones(input_shape, device=device)) mlp = MLP( in_features=cnn_output.shape[-1], activation_class=torch.nn.ReLU, out_features=num_actions, num_cells=[512], + device=device, ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), @@ -91,9 +93,9 @@ def make_dqn_modules_pixels(proof_environment): return qvalue_module -def make_dqn_model(env_name, frame_skip): - proof_environment = make_env(env_name, frame_skip, device="cpu") - qvalue_module = make_dqn_modules_pixels(proof_environment) +def make_dqn_model(env_name, frame_skip, device): + proof_environment = make_env(env_name, frame_skip, device=device) + qvalue_module = make_dqn_modules_pixels(proof_environment, device=device) del proof_environment return qvalue_module diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index a1879519271..05c67b40c3f 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -55,6 +55,7 @@ class EGreedyModule(TensorDictModuleBase): Default is ``"action"``. action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. Default is ``None`` (corresponding to no mask). + device (torch.device, optional): the device of the exploration module. .. note:: It is crucial to incorporate a call to :meth:`~.step` in the training loop @@ -97,6 +98,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", action_mask_key: Optional[NestedKey] = None, + device: torch.device | None = None, ): if not isinstance(eps_init, float): warnings.warn("eps_init should be a float.") @@ -112,14 +114,18 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.as_tensor(eps_init)) - self.register_buffer("eps_end", torch.as_tensor(eps_end)) + self.register_buffer("eps_init", torch.as_tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.as_tensor(eps_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32)) + self.register_buffer( + "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device) + ) if spec is not None: if not isinstance(spec, Composite) and len(self.out_keys) >= 1: spec = Composite({action_key: spec}, shape=spec.shape[:-1]) + if device is not None: + spec = spec.to(device) self._spec = spec @property @@ -183,7 +189,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"Action mask key {self.action_mask_key} not found in {tensordict}." ) spec.update_mask(action_mask) - out = torch.where(cond, spec.rand().to(out.device), out) + r = spec.rand() + if r.device != out.device: + r = r.to(out.device) + out = torch.where(cond, r, out) else: raise RuntimeError("spec must be provided to the exploration wrapper.") action_tensordict.set(action_key, out) From ed89a1083ed0e8e194b4cbfa5dc421611dc6abe9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 15:40:06 -0800 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c03fb40f1ac..7dfe32b93be 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1695,7 +1695,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - @implement_for("torch", None, "2.1") + @implement_for("torch", None, "2.1", compilable=True) def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] @@ -1718,7 +1718,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: # out.scatter_(-1, m, 1) return out - @implement_for("torch", "2.1") + @implement_for("torch", "2.1", compilable=True) def rand(self, shape: torch.Size = None) -> torch.Tensor: # noqa: F811 if shape is None: shape = self.shape[:-1] From 89c8d98b28a6092195716cff2e7d9ed0b12b4575 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 15:51:50 -0800 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- sota-implementations/dqn/utils_atari.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index a135b78803d..9b10b63ef6d 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -39,6 +39,7 @@ def make_env(env_name, frame_skip, device, is_test=False): from_pixels=True, pixels_only=False, device=device, + categorical_action_encoding=True, ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) From 0dc4622accb84ed54270e3d7ce21e853733db307 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 15:55:53 -0800 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7dfe32b93be..4563fd3ca21 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -455,7 +455,7 @@ def __eq__(self, other): ) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class CategoricalBox(Box): """A box of discrete, categorical values.""" @@ -502,7 +502,7 @@ def from_nvec(nvec: torch.Tensor): return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) -@dataclass(repr=False) +@dataclass(repr=False, frozen=True) class BinaryBox(Box): """A box of n binary values.""" @@ -3313,6 +3313,7 @@ def __init__( ) self.update_mask(mask) self._provisional_n = None + self._undefined_n = self.space.n < 0 def enumerate(self) -> torch.Tensor: dtype = self.dtype @@ -3379,7 +3380,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self.space.n < 0: + if self._undefined_n: if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). " From e5a358b596ca2653455d36a4ebca95401e717054 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 17:23:37 -0800 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- sota-implementations/dqn/dqn_atari.py | 9 +++++---- sota-implementations/dqn/dqn_cartpole.py | 9 ++++++--- sota-implementations/dqn/utils_atari.py | 2 +- sota-implementations/dqn/utils_cartpole.py | 11 ++++++----- torchrl/data/tensor_specs.py | 12 ++++++++++-- torchrl/modules/tensordict_module/exploration.py | 3 ++- 6 files changed, 30 insertions(+), 16 deletions(-) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 0e07462220f..255b6b2ee65 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -156,7 +156,9 @@ def update(sampled_tensordict): storing_device=device, max_frames_per_traj=-1, init_random_frames=init_random_frames, - compile_policy={"mode": compile_mode} if compile_mode is not None else False, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, cudagraph_policy=cfg.compile.cudagraphs, ) @@ -212,9 +214,8 @@ def update(sampled_tensordict): # Get and log q-values, loss, epsilon, sampling time and training time log_info.update( { - "train/q_values": (data["action_value"] * data["action"]).sum().item() - / frames_per_batch, - "train/q_loss": q_losses.mean().item(), + "train/q_values": data["chosen_action_value"].sum() / frames_per_batch, + "train/q_loss": q_losses.mean(), "train/epsilon": greedy_module.eps, } ) diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index e51a538d882..89a1e04d586 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -36,18 +36,19 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(device) # Make the components - model = make_dqn_model(cfg.env.env_name) + model = make_dqn_model(cfg.env.env_name, device=device) greedy_module = EGreedyModule( annealing_num_steps=cfg.collector.annealing_frames, eps_init=cfg.collector.eps_start, eps_end=cfg.collector.eps_end, spec=model.spec, + device=device, ) model_explore = TensorDictSequential( model, greedy_module, - ).to(device) + ) # Create the replay buffer replay_buffer = TensorDictReplayBuffer( @@ -135,7 +136,9 @@ def update(sampled_tensordict): storing_device="cpu", max_frames_per_traj=-1, init_random_frames=cfg.collector.init_random_frames, - compile_policy={"mode": compile_mode} if compile_mode is not None else False, + compile_policy={"mode": compile_mode, "fullgraph": True} + if compile_mode is not None + else False, cudagraph_policy=cfg.compile.cudagraphs, ) diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 9b10b63ef6d..0956dfeb2ac 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -88,7 +88,7 @@ def make_dqn_modules_pixels(proof_environment, device): ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["pixels"], ) return qvalue_module diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index d378f1ec76b..c49ff15f5fc 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -31,7 +31,7 @@ def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False): # -------------------------------------------------------------------- -def make_dqn_modules(proof_environment): +def make_dqn_modules(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -45,19 +45,20 @@ def make_dqn_modules(proof_environment): activation_class=torch.nn.ReLU, out_features=num_outputs, num_cells=[120, 84], + device=device, ) qvalue_module = QValueActor( module=mlp, - spec=Composite(action=action_spec), + spec=Composite(action=action_spec).to(device), in_keys=["observation"], ) return qvalue_module -def make_dqn_model(env_name): - proof_environment = make_env(env_name, device="cpu") - qvalue_module = make_dqn_modules(proof_environment) +def make_dqn_model(env_name, device): + proof_environment = make_env(env_name, device=device) + qvalue_module = make_dqn_modules(proof_environment, device=device) del proof_environment return qvalue_module diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4563fd3ca21..1898e679717 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -462,6 +462,11 @@ class CategoricalBox(Box): n: int register = invertible_dict() + def __post_init__(self): + # n could be a numpy array or a tensor, making compile go a bit crazy + # We want to make sure we're working with a regular integer + self.__dict__["n"] = int(self.n) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: return deepcopy(self) @@ -3313,7 +3318,10 @@ def __init__( ) self.update_mask(mask) self._provisional_n = None - self._undefined_n = self.space.n < 0 + + @torch.compiler.assume_constant_result + def _undefined_n(self): + return self.space.n == -1 def enumerate(self) -> torch.Tensor: dtype = self.dtype @@ -3380,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self._undefined_n: + if self._undefined_n(): if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). " diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 05c67b40c3f..6e8296a677a 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -153,7 +153,8 @@ def step(self, frames: int = 1) -> None: ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + expl = exploration_type() + if expl in (ExplorationType.RANDOM, None): if isinstance(self.action_key, tuple) and len(self.action_key) > 1: action_tensordict = tensordict.get(self.action_key[:-1]) action_key = self.action_key[-1]