Skip to content

Commit

Permalink
[Feature] Allow users to add random modules for vmap randomness detec…
Browse files Browse the repository at this point in the history
…tion

ghstack-source-id: 8a14ade21a8369a2da2613557a533f584865fbef
Pull Request resolved: #2317
  • Loading branch information
vmoens committed Jul 24, 2024
1 parent 447970d commit 1ca33a2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ Since the calls to `vmap` are buried down the loss modules, TorchRL
provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see
:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.

``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
other cases. By default, only a limited number of modules are listed as random, but the list can be extended
using the :func:`~torchrl.objectives.common.add_random_module` function.

Training value functions
------------------------

Expand Down
11 changes: 11 additions & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,12 @@ def vmap_randomness(self):
If `"different"`, every element of the batch along which vmap is being called will
behave differently. If `"same"`, vmaps will copy the same result across all elements.
``vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
other cases. By default, only a limited number of modules are listed as random, but the list can be extended
using the :func:`~torchrl.objectives.common.add_random_module` function.
This property supports setting its value.
"""
if self._vmap_randomness is None:
main_modules = list(self.__dict__.values()) + list(self.children())
Expand Down Expand Up @@ -603,3 +608,9 @@ def __call__(self, x):
x.data.clone() if self.clone else x.data, requires_grad=False
)
return x.data.clone() if self.clone else x.data


def add_ramdom_module(module):
"""Adds a random module to the list of modules that will be detected by :meth:`~torchrl.objectives.LossModule.vmap_randomness` as random."""
global RANDOM_MODULE_LIST
RANDOM_MODULE_LIST = RANDOM_MODULE_LIST + (module,)

0 comments on commit 1ca33a2

Please sign in to comment.