Skip to content

Commit

Permalink
[Refactor] Deprecate recurrent_mode API to use decorators/CMs instead
Browse files Browse the repository at this point in the history
ghstack-source-id: 80f705e022abc111df3960fc09576d5e266ed4dd
Pull Request resolved: #2584
  • Loading branch information
vmoens committed Nov 20, 2024
1 parent 0f59226 commit 14b2775
Show file tree
Hide file tree
Showing 12 changed files with 230 additions and 33 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer.
OnlineDTActor
RSSMPosterior
RSSMPrior
set_recurrent_mode
recurrent_mode

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
24 changes: 24 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
DistributionalQValueActor,
OneHotCategorical,
QValueActor,
recurrent_mode,
SafeSequential,
WorldModelWrapper,
)
Expand Down Expand Up @@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs):


class TestBase:
def test_decorators(self):
    """Check the modes seen from inside a ``LossModule`` subclass's methods.

    Inside ``forward`` and the two ``*_loss`` methods, ``recurrent_mode()``
    must be truthy and ``exploration_type()`` must be
    ``ExplorationType.DETERMINISTIC``; once the calls have returned,
    ``recurrent_mode()`` must be falsy again.
    """
    # NOTE(review): presumably LossModule wraps these methods in the
    # corresponding context managers -- confirm against LossModule itself.

    def _assert_loss_context():
        # Shared checks, executed from within each MyLoss method call.
        assert recurrent_mode()
        assert exploration_type() is ExplorationType.DETERMINISTIC

    class MyLoss(LossModule):
        def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            _assert_loss_context()
            return TensorDict()

        def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
            _assert_loss_context()
            return TensorDict()

        def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
            _assert_loss_context()
            return TensorDict()

    loss = MyLoss()
    # Same call order as before: forward first, then the two *_loss methods.
    for method_name in ("forward", "actor_loss", "something_loss"):
        getattr(loss, method_name)(None)
    # Outside of the loss methods the recurrent mode must be off again.
    assert not recurrent_mode()

@pytest.mark.parametrize("expand_dim", [None, 2])
@pytest.mark.parametrize("compare_against", [True, False])
@pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion")
Expand Down
54 changes: 53 additions & 1 deletion test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
OnlineDTActor,
ProbabilisticActor,
SafeModule,
set_recurrent_mode,
TanhDelta,
TanhNormal,
ValueOperator,
Expand Down Expand Up @@ -729,6 +730,31 @@ def test_errs(self):
with pytest.raises(KeyError, match="is_init"):
lstm_module(td)

@pytest.mark.parametrize("default_val", [False, True, None])
def test_set_recurrent_mode(self, default_val):
    """Nested ``set_recurrent_mode`` contexts override and then restore the mode.

    The module is built with ``default_recurrent_mode=default_val``; each
    nested context must win while active, and leaving it must bring back the
    value of the enclosing context (or the default at the outermost level).
    """
    module_kwargs = {
        "input_size": 3,
        "hidden_size": 12,
        "batch_first": True,
        "in_keys": ["observation", "hidden0", "hidden1"],
        "out_keys": ["intermediate", ("next", "hidden0"), ("next", "hidden1")],
        "default_recurrent_mode": default_val,
    }
    module = LSTMModule(**module_kwargs)

    # With no context active the property is exactly bool(default_val).
    assert module.recurrent_mode is bool(default_val)
    with set_recurrent_mode(True):
        assert module.recurrent_mode
        with set_recurrent_mode(False):
            assert not module.recurrent_mode
            # String spellings of the mode are accepted as well.
            with set_recurrent_mode("recurrent"):
                assert module.recurrent_mode
                with set_recurrent_mode("sequential"):
                    assert not module.recurrent_mode
                # Back in the "recurrent" context.
                assert module.recurrent_mode
            # Back in the False context.
            assert not module.recurrent_mode
        # Back in the True context.
        assert module.recurrent_mode
    # All contexts exited: the default applies again.
    assert module.recurrent_mode is bool(default_val)

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_set_temporal_mode(self):
lstm_module = LSTMModule(
input_size=3,
Expand All @@ -754,7 +780,8 @@ def test_python_cudnn(self):
num_layers=2,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)
default_recurrent_mode=True,
)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)
Expand Down Expand Up @@ -1109,6 +1136,31 @@ def test_errs(self):
with pytest.raises(KeyError, match="is_init"):
gru_module(td)

@pytest.mark.parametrize("default_val", [False, True, None])
def test_set_recurrent_mode(self, default_val):
    """Nested ``set_recurrent_mode`` contexts override and then restore the mode.

    The module is built with ``default_recurrent_mode=default_val``; each
    nested context must win while active, and leaving it must bring back the
    value of the enclosing context (or the default at the outermost level).
    """
    module_kwargs = {
        "input_size": 3,
        "hidden_size": 12,
        "batch_first": True,
        "in_keys": ["observation", "hidden"],
        "out_keys": ["intermediate", ("next", "hidden")],
        "default_recurrent_mode": default_val,
    }
    module = GRUModule(**module_kwargs)

    # With no context active the property is exactly bool(default_val).
    assert module.recurrent_mode is bool(default_val)
    with set_recurrent_mode(True):
        assert module.recurrent_mode
        with set_recurrent_mode(False):
            assert not module.recurrent_mode
            # String spellings of the mode are accepted as well.
            with set_recurrent_mode("recurrent"):
                assert module.recurrent_mode
                with set_recurrent_mode("sequential"):
                    assert not module.recurrent_mode
                # Back in the "recurrent" context.
                assert module.recurrent_mode
            # Back in the False context.
            assert not module.recurrent_mode
        # Back in the True context.
        assert module.recurrent_mode
    # All contexts exited: the default applies again.
    assert module.recurrent_mode is bool(default_val)

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_set_temporal_mode(self):
gru_module = GRUModule(
input_size=3,
Expand Down
6 changes: 4 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"):
in_keys=["observation", "rhs", "is_init"],
out_keys=["output", ("next", "rhs")],
device=device,
).set_recurrent_mode(True)
default_recurrent_mode=True,
)

def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
return LSTMModule(
Expand All @@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
in_keys=["observation", "rhs_h", "rhs_c", "is_init"],
out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")],
device=device,
).set_recurrent_mode(True)
default_recurrent_mode=True,
)

def _make_batch(self, batch_size: int = 2, sequence_length: int = 5):
observation = torch.randn(batch_size, sequence_length + 1, 4)
Expand Down
23 changes: 23 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import os
import pickle
import sys
import threading
import time
import traceback
import warnings
from contextlib import nullcontext
from copy import copy
from distutils.util import strtobool
from functools import wraps
Expand All @@ -32,6 +34,11 @@
from tensordict.utils import NestedKey
from torch import multiprocessing as mp

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling

LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
logger = logging.getLogger("torchrl")
logger.setLevel(getattr(logging, LOGGING_LEVEL))
Expand Down Expand Up @@ -827,3 +834,19 @@ def _make_ordinal_device(device: torch.device):
if device.type == "mps" and device.index is None:
return torch.device("mps", index=0)
return device


class _ContextManager:
def __init__(self):
self._mode: Any | None = None
self._lock = threading.Lock()

def get_mode(self) -> Any | None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
return self._mode

def set_mode(self, type: Any | None) -> None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
self._mode = type
3 changes: 2 additions & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7411,7 +7411,8 @@ class BurnInTransform(Transform):
... hidden_size=10,
... in_keys=["observation", "hidden"],
... out_keys=["intermediate", ("next", "hidden")],
... ).set_recurrent_mode(True)
... default_recurrent_mode=True,
... )
>>> burn_in_transform = BurnInTransform(
... modules=[gru_module],
... burn_in=5,
Expand Down
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@
QValueActor,
QValueHook,
QValueModule,
recurrent_mode,
SafeModule,
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
SafeSequential,
set_recurrent_mode,
TanhModule,
ValueOperator,
VmapModule,
Expand Down
11 changes: 10 additions & 1 deletion torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule
from .rnn import (
GRU,
GRUCell,
GRUModule,
LSTM,
LSTMCell,
LSTMModule,
recurrent_mode,
set_recurrent_mode,
)
from .sequence import SafeSequential
from .world_models import WorldModelWrapper
Loading

0 comments on commit 14b2775

Please sign in to comment.