[BugFix] SafeModule not safely handling specs (pytorch#1352)
Signed-off-by: Matteo Bettini <matbet@meta.com>
matteobettini authored Jul 5, 2023
1 parent e09d2b3 commit e4fef6b
Showing 5 changed files with 65 additions and 10 deletions.
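The heart of the fix: SafeModule and its relatives normalize a user-provided spec by padding every out_key missing from it with a None entry. Before this commit that padding happened in place on the caller's spec; the patch clones the spec first. A minimal repro sketch, assuming a torchrl build from this era with the patch applied (names mirror the new test_safe_specs test below):

import torch.nn as nn

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.modules import SafeModule

out_key = ("a", "b")
spec = CompositeSpec(CompositeSpec({out_key: UnboundedContinuousTensorSpec()}))
original_spec = spec.clone()

# SafeModule pads specs for out_keys absent from `spec` (here ("other", "key"))
# with None. With this fix the padding happens on an internal clone, so the
# caller's spec is left untouched.
mod = SafeModule(
    module=nn.Linear(3, 1),
    spec=spec,
    out_keys=[out_key, ("other", "key")],
    in_keys=[],
)
assert spec == original_spec  # fails before this commit, passes after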
test/test_tensordictmodules.py (48 additions, 2 deletions)

@@ -16,7 +16,14 @@
     UnboundedContinuousTensorSpec,
 )
 from torchrl.envs.utils import set_exploration_type, step_mdp
-from torchrl.modules import LSTMModule, NormalParamWrapper, SafeModule, TanhNormal
+from torchrl.modules import (
+    AdditiveGaussianWrapper,
+    LSTMModule,
+    NormalParamWrapper,
+    SafeModule,
+    TanhNormal,
+    ValueOperator,
+)
 from torchrl.modules.tensordict_module.common import (
     ensure_tensordict_compatible,
     is_tensordict_compatible,
@@ -27,7 +34,7 @@
     SafeProbabilisticTensorDictSequential,
 )
 from torchrl.modules.tensordict_module.sequence import SafeSequential
-
+from torchrl.objectives import DDPGLoss
 
 _has_functorch = False
 try:
@@ -1730,6 +1737,45 @@ def test_multi_consecutive(self, shape):
         )
 
 
+def test_safe_specs():
+
+    out_key = ("a", "b")
+    spec = CompositeSpec(CompositeSpec({out_key: UnboundedContinuousTensorSpec()}))
+    original_spec = spec.clone()
+    mod = SafeModule(
+        module=nn.Linear(3, 1),
+        spec=spec,
+        out_keys=[out_key, ("other", "key")],
+        in_keys=[],
+    )
+    assert original_spec == spec
+    assert original_spec[out_key] == mod.spec[out_key]
+
+
+def test_actor_critic_specs():
+    action_key = ("agents", "action")
+    spec = CompositeSpec(
+        CompositeSpec({action_key: UnboundedContinuousTensorSpec(shape=(3,))})
+    )
+    policy_module = TensorDictModule(
+        nn.Linear(3, 1),
+        in_keys=[("agents", "observation")],
+        out_keys=[action_key],
+    )
+    original_spec = spec.clone()
+    module = AdditiveGaussianWrapper(policy_module, spec=spec, action_key=action_key)
+    value_module = ValueOperator(
+        module=module,
+        in_keys=[("agents", "observation"), action_key],
+        out_keys=[("agents", "state_action_value")],
+    )
+    assert original_spec == spec
+    assert module.spec == spec
+    DDPGLoss(actor_network=module, value_network=value_module)
+    assert original_spec == spec
+    assert module.spec == spec
+
+
 def test_vmapmodule():
     lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
     sample_in = torch.ones((10, 3, 2))
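test_actor_critic_specs pushes the same invariant through a full actor-critic setup: the spec must survive AdditiveGaussianWrapper, ValueOperator, and even DDPGLoss construction unchanged, presumably because loss construction also inspects the actor's spec and must not alter it.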
torchrl/data/tensor_specs.py (7 additions, 1 deletion)

@@ -2582,7 +2582,13 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
                 item = CompositeSpec(item, shape=shape)
             if item is not None:
                 if self._device is None:
-                    self._device = item.device
+                    try:
+                        self._device = item.device
+                    except RuntimeError as err:
+                        if DEVICE_ERR_MSG in str(err):
+                            self._device = item._device
+                        else:
+                            raise err
             self[k] = item
 
     @property
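Why the guard: CompositeSpec.device raises a RuntimeError carrying the module-level DEVICE_ERR_MSG when no device has been set, and nested or cloned specs can legitimately be device-less. A hedged sketch of the failure mode this catches, assuming era-appropriate torchrl behavior:

from torchrl.data import CompositeSpec

spec = CompositeSpec({})  # no device supplied
try:
    _ = spec.device  # raises RuntimeError containing DEVICE_ERR_MSG
except RuntimeError:
    # the patched __init__ catches exactly this case and falls back to the
    # private `_device` attribute, which may simply be None
    pass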
torchrl/modules/tensordict_module/common.py (4 additions, 2 deletions)

@@ -210,6 +210,8 @@ def __init__(
         self.register_spec(safe=safe, spec=spec)
 
     def register_spec(self, safe, spec):
+        if spec is not None:
+            spec = spec.clone()
         if spec is not None and not isinstance(spec, TensorSpec):
             raise TypeError("spec must be a TensorSpec subclass")
         elif spec is not None and not isinstance(spec, CompositeSpec):
@@ -230,8 +232,8 @@ def register_spec(self, safe, spec):
         out_keys = set(unravel_key_list(self.out_keys))
         if spec_keys != out_keys:
             # then assume that all the non indicated specs are None
-            for key in self.out_keys:
-                if key not in spec:
+            for key in out_keys:
+                if key not in spec_keys:
                     spec[key] = None
             spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
         if spec_keys != out_keys:
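The register_spec change, reduced to its essence with a plain dict standing in for CompositeSpec (an illustrative stand-in, not TorchRL code):

def register_spec(spec, out_keys):
    """Clone-before-padding: normalize a spec without mutating the caller's."""
    spec = dict(spec)  # defensive copy, the dict analogue of spec.clone()
    for key in out_keys:
        if key not in spec:
            spec[key] = None  # unspecified outputs get a None placeholder
    return spec

user_spec = {("a", "b"): "leaf-spec"}
internal = register_spec(user_spec, [("a", "b"), ("other", "key")])
assert ("other", "key") in internal       # the module sees the padded spec
assert ("other", "key") not in user_spec  # the caller's spec is untouched

The second hunk also fixes the membership test itself: it now compares the unravel_key_list-normalized sets (out_keys against spec_keys) instead of raw self.out_keys entries against the spec, so nested tuple keys are matched in the same canonical form on both sides.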
torchrl/modules/tensordict_module/exploration.py (2 additions, 2 deletions)

@@ -229,11 +229,11 @@ def __init__(
             self._spec = spec
         elif hasattr(self.td_module, "_spec"):
             self._spec = self.td_module._spec.clone()
-            if action_key not in self._spec.keys():
+            if action_key not in self._spec.keys(True, True):
                 self._spec[action_key] = None
         elif hasattr(self.td_module, "spec"):
             self._spec = self.td_module.spec.clone()
-            if action_key not in self._spec.keys():
+            if action_key not in self._spec.keys(True, True):
                 self._spec[action_key] = None
         else:
             self._spec = CompositeSpec({key: None for key in policy.out_keys})
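keys() on a CompositeSpec yields only top-level entries, so with a nested action key like ("agents", "action") the old check never found the existing spec and clobbered it with None. keys(True, True) (include_nested=True, leaves_only=True) yields leaf keys as tuples. A small sketch, assuming era-appropriate torchrl semantics:

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

spec = CompositeSpec(
    {("agents", "action"): UnboundedContinuousTensorSpec(shape=(3,))}
)
# the top level only exposes "agents"; the nested leaf needs keys(True, True)
assert ("agents", "action") not in list(spec.keys())
assert ("agents", "action") in list(spec.keys(True, True))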
torchrl/modules/tensordict_module/probabilistic.py (4 additions, 3 deletions)

@@ -120,7 +120,8 @@ def __init__(
             cache_dist=cache_dist,
             n_empirical_estimate=n_empirical_estimate,
         )
-
+        if spec is not None:
+            spec = spec.clone()
         if spec is not None and not isinstance(spec, TensorSpec):
             raise TypeError("spec must be a TensorSpec subclass")
         elif spec is not None and not isinstance(spec, CompositeSpec):
@@ -139,8 +140,8 @@ def __init__(
         out_keys = set(unravel_key_list(self.out_keys))
         if spec_keys != out_keys:
             # then assume that all the non indicated specs are None
-            for key in self.out_keys:
-                if key not in spec:
+            for key in out_keys:
+                if key not in spec_keys:
                     spec[key] = None
             spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
 
Expand Down
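The probabilistic counterpart (SafeProbabilisticModule) receives the same two-part treatment as SafeModule: clone the incoming spec before the None-padding normalization, and compare the normalized key sets rather than raw out_keys against the spec.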
