From dd70b78389504634ba3cf2bcf1ea517c606c61a4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jul 2024 16:35:23 +0100 Subject: [PATCH] [BugFix] Expose MARL modules (#2321) --- torchrl/modules/models/multiagent.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 1a92c4b33eb..6ccc4721678 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -70,8 +70,23 @@ def __init__( break self.initialized = initialized self._make_params(agent_networks) + # We make sure all params and buffers are on 'meta' device + # To do this, we set the device keyword arg to 'meta', we also temporarily change + # the default device. Finally, we convert all params to 'meta' tensors that are not params. kwargs["device"] = "meta" - self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + with torch.device("meta"): + try: + self._empty_net = self._build_single_net(**kwargs) + except NotImplementedError as err: + if "Cannot copy out of meta tensor" in str(err): + raise RuntimeError( + "The network was built using `factory().to(device), build the network directly " + "on device using `factory(device=device)` instead." + ) + # Remove all parameters + TensorDict.from_module(self._empty_net).data.to("meta").to_module( + self._empty_net + ) @property def vmap_randomness(self): @@ -225,7 +240,7 @@ def from_stateful_net(self, stateful_net: nn.Module): self.params.data.update_(params.data) def __repr__(self): - empty_net = self.__dict__["_empty_net"] + empty_net = self._empty_net with self.params.to_module(empty_net): module_repr = indent(str(empty_net), 4 * " ") n_agents = indent(f"n_agents={self.n_agents}", 4 * " ")