From 38d9cb7c73f9382ed6e9d916c7609a222aa9dcab Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 27 Nov 2023 13:30:46 +0000
Subject: [PATCH] [BugFix] Make casting to 'meta' device uniform across cost
 modules (#1715)

---
 torchrl/envs/transforms/rlhf.py         |  2 +-
 torchrl/objectives/common.py            | 26 +++++++++++++-------------
 torchrl/objectives/ddpg.py              |  2 +-
 torchrl/objectives/multiagent/qmixer.py |  4 +++-
 4 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 48464d9f9c4..79ee94318cb 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -112,7 +112,7 @@ def __init__(
 
         # check that the model has parameters
         params = TensorDict.from_module(actor)
-        with params.apply(_stateless_param).to_module(actor):
+        with params.apply(_stateless_param, device="meta").to_module(actor):
             # copy a stateless actor
             self.__dict__["functional_actor"] = deepcopy(actor)
         # we need to register these params as buffer to have `to` and similar
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 1a99ffca108..76e5ef10900 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -289,9 +289,9 @@ def _compare_and_expand(param):
 
         # set the functional module: we need to convert the params to non-differentiable params
         # otherwise they will appear twice in parameters
-        with params.apply(_make_meta_params, device=torch.device("meta")).to_module(
-            module
-        ):
+        with params.apply(
+            self._make_meta_params, device=torch.device("meta")
+        ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)
 
@@ -435,6 +435,16 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
 
         return self
 
+    @staticmethod
+    def _make_meta_params(param):
+        is_param = isinstance(param, nn.Parameter)
+
+        pd = param.detach().to("meta")
+
+        if is_param:
+            pd = nn.Parameter(pd, requires_grad=False)
+        return pd
+
 
 class _make_target_param:
     def __init__(self, clone):
@@ -446,13 +456,3 @@ def __call__(self, x):
                 x.data.clone() if self.clone else x.data, requires_grad=False
             )
         return x.data.clone() if self.clone else x.data
-
-
-def _make_meta_params(param):
-    is_param = isinstance(param, nn.Parameter)
-
-    pd = param.detach().to("meta")
-
-    if is_param:
-        pd = nn.Parameter(pd, requires_grad=False)
-    return pd
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 7d94a5eb07b..a72b84f69e4 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -198,7 +198,7 @@ def __init__(
 
         actor_critic = ActorCriticWrapper(actor_network, value_network)
         params = TensorDict.from_module(actor_critic)
-        params_meta = params.detach().to("meta")
+        params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
         with params_meta.to_module(actor_critic):
             self.actor_critic = deepcopy(actor_critic)
 
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index 61abab6216f..35e03d35744 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -213,7 +213,9 @@ def __init__(
 
         global_value_network = SafeSequential(local_value_network, mixer_network)
         params = TensorDict.from_module(global_value_network)
-        with params.detach().to("meta").to_module(global_value_network):
+        with params.apply(
+            self._make_meta_params, device=torch.device("meta")
+        ).to_module(global_value_network):
             self.global_value_network = deepcopy(global_value_network)
 
         self.convert_to_functional(
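
The pattern this patch makes uniform: each cost module maps its parameters onto the
"meta" device before deep-copying the module, so the copy keeps the module structure
(shapes, dtypes) without real storage and without registering the same parameters a
second time. Below is a minimal sketch of that pattern, assuming the tensordict API
used in the diff (`TensorDict.from_module`, `apply(..., device=...)`, `to_module`);
the module and variable names are illustrative only, not part of the patch.

    from copy import deepcopy

    import torch
    from tensordict import TensorDict
    from torch import nn


    def _make_meta_params(param):
        # Detach and move to the "meta" device: the result keeps shape/dtype but no data.
        is_param = isinstance(param, nn.Parameter)
        pd = param.detach().to("meta")
        if is_param:
            # Re-wrap as a non-differentiable Parameter so the copy exposes no trainable params.
            pd = nn.Parameter(pd, requires_grad=False)
        return pd


    module = nn.Linear(3, 4)  # illustrative stand-in for an actor/value network
    params = TensorDict.from_module(module)

    # Temporarily swap the real parameters for meta ones, then deepcopy: the copy is
    # "stateless" and its parameters are not duplicated in the loss module's parameters().
    with params.apply(_make_meta_params, device=torch.device("meta")).to_module(module):
        stateless_copy = deepcopy(module)
    # Outside the context manager, `module` has its original parameters back.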