[BugFix] Make casting to 'meta' device uniform across cost modules (#…
vmoens authored Nov 27, 2023
1 parent 0f93943 commit 38d9cb7
Showing 4 changed files with 18 additions and 16 deletions.
torchrl/envs/transforms/rlhf.py: 1 addition, 1 deletion
@@ -112,7 +112,7 @@ def __init__(

         # check that the model has parameters
         params = TensorDict.from_module(actor)
-        with params.apply(_stateless_param).to_module(actor):
+        with params.apply(_stateless_param, device="meta").to_module(actor):
             # copy a stateless actor
             self.__dict__["functional_actor"] = deepcopy(actor)
             # we need to register these params as buffer to have `to` and similar
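For context, the rlhf.py change adds the `device="meta"` keyword so the TensorDict built by `apply` carries the meta device, matching the storage-free leaves the mapped function returns. A minimal sketch of the call, using a hypothetical `to_meta` leaf function in place of torchrl's `_stateless_param` (whose body is not shown in this diff):

```python
# Sketch of the apply(..., device="meta") call added above; `to_meta` is a
# hypothetical stand-in for torchrl's `_stateless_param`.
import torch
from torch import nn
from tensordict import TensorDict

actor = nn.Linear(4, 2)
params = TensorDict.from_module(actor)


def to_meta(p):
    # detach each leaf and move it to the storage-free "meta" device
    return nn.Parameter(p.detach().to("meta"), requires_grad=False)


meta_params = params.apply(to_meta, device="meta")
print(meta_params.device)            # meta
print(meta_params["weight"].device)  # meta
```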
torchrl/objectives/common.py: 13 additions, 13 deletions
@@ -289,9 +289,9 @@ def _compare_and_expand(param):

         # set the functional module: we need to convert the params to non-differentiable params
         # otherwise they will appear twice in parameters
-        with params.apply(_make_meta_params, device=torch.device("meta")).to_module(
-            module
-        ):
+        with params.apply(
+            self._make_meta_params, device=torch.device("meta")
+        ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)

@@ -435,6 +435,16 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams

         return self
 
+    @staticmethod
+    def _make_meta_params(param):
+        is_param = isinstance(param, nn.Parameter)
+
+        pd = param.detach().to("meta")
+
+        if is_param:
+            pd = nn.Parameter(pd, requires_grad=False)
+        return pd
+
 
 class _make_target_param:
     def __init__(self, clone):
@@ -446,13 +456,3 @@ def __call__(self, x):
                 x.data.clone() if self.clone else x.data, requires_grad=False
             )
         return x.data.clone() if self.clone else x.data
-
-
-def _make_meta_params(param):
-    is_param = isinstance(param, nn.Parameter)
-
-    pd = param.detach().to("meta")
-
-    if is_param:
-        pd = nn.Parameter(pd, requires_grad=False)
-    return pd
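The helper itself is unchanged; it only moves from a module-level function (the deleted lines above) to a `@staticmethod` on `LossModule`, so the other cost modules can reuse it as `self._make_meta_params`. At the leaf level, a bare `detach().to("meta")` yields a plain `Tensor`, so `nn.Parameter` leaves are re-wrapped as non-differentiable parameters; per the comment in `convert_to_functional`, that keeps the stateless copy's weights from showing up twice in the module's parameters. A quick leaf-level sketch:

```python
# Leaf-level sketch of what _make_meta_params returns, compared with the bare
# detach().to("meta") cast the other files used before this commit.
import torch
from torch import nn

w = nn.Parameter(torch.randn(3, 3))

bare = w.detach().to("meta")            # plain Tensor on meta, no longer a Parameter
wrapped = nn.Parameter(w.detach().to("meta"), requires_grad=False)

print(type(bare).__name__, bare.device)                               # Tensor meta
print(type(wrapped).__name__, wrapped.device, wrapped.requires_grad)  # Parameter meta False
```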
torchrl/objectives/ddpg.py: 1 addition, 1 deletion
@@ -198,7 +198,7 @@ def __init__(

         actor_critic = ActorCriticWrapper(actor_network, value_network)
         params = TensorDict.from_module(actor_critic)
-        params_meta = params.detach().to("meta")
+        params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
         with params_meta.to_module(actor_critic):
             self.actor_critic = deepcopy(actor_critic)
torchrl/objectives/multiagent/qmixer.py: 3 additions, 1 deletion
@@ -213,7 +213,9 @@ def __init__(

         global_value_network = SafeSequential(local_value_network, mixer_network)
         params = TensorDict.from_module(global_value_network)
-        with params.detach().to("meta").to_module(global_value_network):
+        with params.apply(
+            self._make_meta_params, device=torch.device("meta")
+        ).to_module(global_value_network):
             self.global_value_network = deepcopy(global_value_network)
 
         self.convert_to_functional(
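With the helper exposed on `LossModule`, the ddpg.py and qmixer.py constructors now build their stateless copies through the same call as `convert_to_functional`. A rough end-to-end sketch of that shared pattern, using a toy network in place of the actor-critic / value-mixer stacks; `make_meta_params` below mirrors the staticmethod added above:

```python
# End-to-end sketch of the now-uniform meta cast: snapshot the params, remap
# them to "meta", and deepcopy the module while the meta params are swapped in.
from copy import deepcopy

import torch
from torch import nn
from tensordict import TensorDict


def make_meta_params(param):
    # mirrors LossModule._make_meta_params from the common.py hunk above
    pd = param.detach().to("meta")
    if isinstance(param, nn.Parameter):
        pd = nn.Parameter(pd, requires_grad=False)
    return pd


network = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
params = TensorDict.from_module(network)
params_meta = params.apply(make_meta_params, device=torch.device("meta"))

with params_meta.to_module(network):
    stateless_copy = deepcopy(network)  # the copy holds only meta, frozen params

# the real parameters are restored on the original module when the block exits
assert next(network.parameters()).device.type != "meta"
assert all(
    p.device.type == "meta" and not p.requires_grad
    for p in stateless_copy.parameters()
)
```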
