Skip to content

Commit

Permalink
[BugFix] Reinitialize vmap callers after reset of vmap randomness
Browse files Browse the repository at this point in the history
ghstack-source-id: 0598afe1ed3bb118054f5021b35a8681ce964615
Pull Request resolved: pytorch#2315
  • Loading branch information
vmoens committed Jul 24, 2024
1 parent f840a1a commit 21b297e
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 21 deletions.
3 changes: 3 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def __init__(self):
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self._make_vmap()

def _make_vmap(self):
self.vmap_model = _vmap_func(
self.model,
(None, 0),
Expand Down
28 changes: 17 additions & 11 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple

import torch
from tensordict import is_tensor_collection, TensorDict, TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
Expand Down Expand Up @@ -541,16 +540,16 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
@property
def vmap_randomness(self):
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
main_modules = list(self.__dict__.values()) + list(self.children())
modules = (
module
for main_module in main_modules
if isinstance(main_module, nn.Module)
for module in main_module.modules()
)
for val in modules:
if isinstance(val, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
break
else:
self._vmap_randomness = "error"
Expand All @@ -559,6 +558,7 @@ def vmap_randomness(self):

def set_vmap_randomness(self, value):
    """Overrides the stored vmap randomness mode and rebuilds the vmap callers.

    Args:
        value: the new content of ``_vmap_randomness``. The ``vmap_randomness``
            property itself assigns ``"different"`` or ``"error"``; other
            accepted values are not visible here.
    """
    self._vmap_randomness = value
    # Re-run the caller construction so that it happens after the new value
    # has been stored (subclass implementations read ``self.vmap_randomness``).
    self._make_vmap()

@staticmethod
def _make_meta_params(param):
Expand All @@ -570,6 +570,12 @@ def _make_meta_params(param):
pd = nn.Parameter(pd, requires_grad=False)
return pd

def _make_vmap(self):
    """Caches the vmap callers to reduce the overhead at runtime.

    This base implementation does not build anything: it only raises, so a
    loss type on which ``_make_vmap`` is called must supply its own version.

    Raises:
        NotImplementedError: always; the message names the concrete loss
            type via ``type(self).__name__``.
    """
    raise NotImplementedError(
        f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
    )


class _make_target_param:
def __init__(self, clone):
Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,16 @@ def __init__(
"log_alpha_prime",
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qvalue_networkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self.reduction = reduction

@property
def target_entropy(self):
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,13 @@ def __init__(

self._target_entropy = target_entropy
self._action_spec = action_spec
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self.reduction = reduction

@property
def target_entropy_buffer(self):
Expand Down
6 changes: 4 additions & 2 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,15 @@ def __init__(
self._action_spec = action_spec
self.target_entropy_buffer = None
self.gSDE = gSDE
self._make_vmap()
self.reduction = reduction

self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

def _make_vmap(self):
    """Stores the ``_vmap_func`` wrapper of ``qvalue_network`` on the instance.

    The result is kept in ``_vmap_qvalue_networkN0``. No ``randomness``
    argument is passed to ``_vmap_func`` here.
    """
    # NOTE(review): (None, 0) is presumably the vmap ``in_dims`` (first input
    # not mapped, second mapped over dim 0) -- confirm against _vmap_func.
    self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

@property
def target_entropy(self):
target_entropy = self.target_entropy_buffer
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,13 @@ def __init__(
self.loss_function = loss_function
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qvalue_networkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self.reduction = reduction

@property
def device(self) -> torch.device:
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def __init__(
self.gSDE = gSDE
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()

def _make_vmap(self):
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
Expand Down
12 changes: 9 additions & 3 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,17 @@ def __init__(
)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
if self._version == 1:
self._vmap_qnetwork00 = _vmap_func(
qvalue_network, randomness=self.vmap_randomness
self.qvalue_network, randomness=self.vmap_randomness
)
self.reduction = reduction

@property
def target_entropy_buffer(self):
Expand Down Expand Up @@ -1101,10 +1104,13 @@ def __init__(
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
)
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,16 @@ def __init__(
self.register_buffer("min_action", low)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self._vmap_actor_network00 = _vmap_func(
self.actor_network, randomness=self.vmap_randomness
)
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/td3_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,16 @@ def __init__(
high = high.to(device)
self.register_buffer("max_action", high)
self.register_buffer("min_action", low)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self._vmap_actor_network00 = _vmap_func(
self.actor_network, randomness=self.vmap_randomness
)
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down

0 comments on commit 21b297e

Please sign in to comment.