Skip to content

Commit

Permalink
[BugFix] Expose MARL modules (pytorch#2321)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 25, 2024
1 parent 8a74642 commit dd70b78
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,23 @@ def __init__(
break
self.initialized = initialized
self._make_params(agent_networks)
# We make sure all params and buffers are on 'meta' device
# To do this, we set the device keyword arg to 'meta', we also temporarily change
# the default device. Finally, we convert all params to 'meta' tensors that are not params.
kwargs["device"] = "meta"
self.__dict__["_empty_net"] = self._build_single_net(**kwargs)
with torch.device("meta"):
try:
self._empty_net = self._build_single_net(**kwargs)
except NotImplementedError as err:
if "Cannot copy out of meta tensor" in str(err):
raise RuntimeError(
"The network was built using `factory().to(device), build the network directly "
"on device using `factory(device=device)` instead."
)
# Remove all parameters
TensorDict.from_module(self._empty_net).data.to("meta").to_module(
self._empty_net
)

@property
def vmap_randomness(self):
Expand Down Expand Up @@ -225,7 +240,7 @@ def from_stateful_net(self, stateful_net: nn.Module):
self.params.data.update_(params.data)

def __repr__(self):
empty_net = self.__dict__["_empty_net"]
empty_net = self._empty_net
with self.params.to_module(empty_net):
module_repr = indent(str(empty_net), 4 * " ")
n_agents = indent(f"n_agents={self.n_agents}", 4 * " ")
Expand Down

0 comments on commit dd70b78

Please sign in to comment.