Skip to content

Commit

Permalink
[BugFix] Vmap randomness for value estimator (#1942)
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 authored Feb 21, 2024
1 parent 23bf315 commit 4080cf3
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp

from torchrl.objectives.utils import _vmap_func, hold_out_net
from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td0_return_estimate,
Expand Down Expand Up @@ -78,6 +78,7 @@ def _call_value_nets(
single_call: bool,
value_key: NestedKey,
detach_next: bool,
vmap_randomness: str = "error",
):
in_keys = value_net.in_keys
if single_call:
Expand Down Expand Up @@ -141,9 +142,11 @@ def _call_value_nets(
)
elif params is not None:
params_stack = torch.stack([params, next_params], 0).contiguous()
data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack)
data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
data_in, params_stack
)
else:
data_out = vmap(value_net, (0,))(data_in)
data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
value_est = data_out.get(value_key)
value, value_ = value_est[0], value_est[1]
data.set(value_key, value)
Expand Down Expand Up @@ -214,6 +217,7 @@ class _AcceptedKeys:

default_keys = _AcceptedKeys()
value_network: Union[TensorDictModule, Callable]
_vmap_randomness = None

@property
def advantage_key(self):
Expand Down Expand Up @@ -428,6 +432,28 @@ def _next_value(self, tensordict, target_params, kwargs):
next_value = step_td.get(self.tensor_keys.value)
return next_value

@property
def vmap_randomness(self):
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
break
else:
self._vmap_randomness = "error"

return self._vmap_randomness

def set_vmap_randomness(self, value):
self._vmap_randomness = value


class TD0Estimator(ValueEstimatorBase):
"""Temporal Difference (TD(0)) estimate of advantage function.
Expand Down Expand Up @@ -589,6 +615,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -790,6 +817,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1001,6 +1029,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1247,6 +1276,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1329,6 +1359,7 @@ def value_estimate(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down Expand Up @@ -1575,6 +1606,7 @@ def forward(
single_call=self.shifted,
value_key=self.tensor_keys.value,
detach_next=True,
vmap_randomness=self.vmap_randomness,
)
else:
value = tensordict.get(self.tensor_keys.value)
Expand Down

0 comments on commit 4080cf3

Please sign in to comment.