[Refactor] Update vmap imports to torch (#1102)
vmoens authored Apr 28, 2023
1 parent 8d8256a commit 4793640
Showing 8 changed files with 26 additions and 8 deletions.
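Every hunk below applies the same idiom: try the vmap that ships in core torch (merged from functorch as of torch 2.0) and fall back to the standalone functorch package on older installs. A minimal sketch of the full pattern follows; the diffs truncate before the outer except branch, so its body here is an assumption based on the flag names visible in the hunks:

# Prefer the vmap built into core torch (torch >= 2.0); otherwise fall back to
# the standalone functorch package. The outer try records whether either worked.
try:
    try:
        from torch import vmap
    except ImportError:
        from functorch import vmap

    _has_functorch = True
    FUNCTORCH_ERR = ""
except ImportError as err:
    # Assumed shape of the truncated branch: remember the failure for later errors.
    _has_functorch = False
    FUNCTORCH_ERR = str(err)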
5 changes: 4 additions & 1 deletion test/test_tensordictmodules.py
@@ -29,7 +29,10 @@

 _has_functorch = False
 try:
-    from torch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     _has_functorch = True
 except ImportError:
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/common.py
@@ -180,7 +180,7 @@ class SafeModule(TensorDictModule):
     One can use a vmap operator to call the functional module. In this case the tensordict is expanded to match the
     batch size (i.e. the tensordict isn't modified in-place anymore):
         >>> # Model ensemble using vmap
-        >>> from functorch import vmap
+        >>> from torch import vmap
         >>> params_repeat = params.expand(4, *params.shape)
         >>> td_vmap = vmap(td_fmodule, (None, 0))(td.clone(), params_repeat)
         >>> print(td_vmap)
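For readers skimming the docstring hunk above, here is a self-contained sketch of the same ensemble idiom with plain torch.vmap on raw tensors; the function and tensor names are illustrative, not torchrl API:

import torch

def linear(weight, x):
    # One ensemble member: apply a single (out, in) weight to the shared input.
    return x @ weight.T

weights = torch.randn(4, 3, 5)  # 4 ensemble members, each a (3, 5) weight
x = torch.randn(8, 5)           # one shared batch of inputs

# Map over dim 0 of `weights` only; `x` is shared across members (in_dims=(0, None)),
# mirroring the (None, 0) pattern in the docstring with the argument order swapped.
out = torch.vmap(linear, in_dims=(0, None))(weights, x)
print(out.shape)  # torch.Size([4, 8, 3])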
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/sequence.py
@@ -83,7 +83,7 @@ class SafeSequential(TensorDictSequential, SafeModule):
             shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous))
     In the vmap case:
-        >>> from functorch import vmap
+        >>> from torch import vmap
         >>> params = params.expand(4, *params.shape)
         >>> td_vmap = vmap(td_module, (None, 0))(td, params)
         >>> print(td_vmap)
5 changes: 4 additions & 1 deletion torchrl/objectives/deprecated.py
@@ -26,7 +26,10 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
 try:
-    from functorch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     FUNCTORCH_ERR = ""
     _has_functorch = True
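The objectives modules pair the import with a flag (`_has_functorch`) and a stored error message (`FUNCTORCH_ERR` here, `err` in iql.py and sac.py below). A hedged sketch of how such a flag is typically consumed; the guard function itself is hypothetical and not part of this diff:

def _check_vmap_available() -> None:
    # Hypothetical guard: raise a readable error at use time rather than import time.
    if not _has_functorch:
        raise ImportError(
            "this loss module requires vmap (torch >= 2.0 or functorch): "
            + FUNCTORCH_ERR
        )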
5 changes: 4 additions & 1 deletion torchrl/objectives/iql.py
@@ -23,7 +23,10 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
 try:
-    from functorch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     _has_functorch = True
     err = ""
5 changes: 4 additions & 1 deletion torchrl/objectives/redq.py
@@ -25,7 +25,10 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
 try:
-    from functorch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     FUNCTORCH_ERR = ""
     _has_functorch = True
5 changes: 4 additions & 1 deletion torchrl/objectives/sac.py
@@ -27,7 +27,10 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
 try:
-    from functorch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     _has_functorch = True
     err = ""
5 changes: 4 additions & 1 deletion torchrl/objectives/td3.py
@@ -20,7 +20,10 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
 try:
-    from functorch import vmap
+    try:
+        from torch import vmap
+    except ImportError:
+        from functorch import vmap
 
     FUNCTORCH_ERR = ""
     _has_functorch = True
