Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] SafeModule not safely handling specs #1352

Merged
merged 12 commits into from
Jul 5, 2023
Merged
Prev Previous commit
Next Next commit
probabilistic
Signed-off-by: Matteo Bettini <matbet@meta.com>
  • Loading branch information
matteobettini committed Jul 4, 2023
commit 8f0ad2318d9395347a879602d1054d3c90fa9ff6
4 changes: 3 additions & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase
from tensordict.utils import unravel_key_list

from torch import nn

Expand Down Expand Up @@ -226,13 +227,14 @@ def register_spec(self, safe, spec):
elif spec is None:
spec = CompositeSpec()

spec_keys = spec.keys(True, True)
spec_keys = unravel_key_list(list(spec.keys(True, True)))
if set(spec_keys) != set(self.out_keys):
# then assume that all the non indicated specs are None
for key in self.out_keys:
if key not in spec_keys:
spec[key] = None

spec_keys = unravel_key_list(list(spec.keys(True, True)))
if set(spec_keys) != set(self.out_keys):
raise RuntimeError(
f"spec keys and out_keys do not match, got: {spec_keys} and {set(self.out_keys)} respectively"
Expand Down
14 changes: 9 additions & 5 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TensorDictModule,
)
from tensordict.tensordict import TensorDictBase
from tensordict.utils import unravel_key_list

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.distributions import Delta
Expand Down Expand Up @@ -120,7 +121,8 @@ def __init__(
cache_dist=cache_dist,
n_empirical_estimate=n_empirical_estimate,
)

if spec is not None:
spec = spec.clone()
if spec is not None and not isinstance(spec, TensorSpec):
raise TypeError("spec must be a TensorSpec subclass")
elif spec is not None and not isinstance(spec, CompositeSpec):
Expand All @@ -136,15 +138,17 @@ def __init__(
elif spec is None:
spec = CompositeSpec()

if set(spec.keys(True, True)) != set(self.out_keys):
spec_keys = unravel_key_list(list(spec.keys(True, True)))
if set(spec_keys) != set(self.out_keys):
# then assume that all the non indicated specs are None
for key in self.out_keys:
if key not in spec:
if key not in spec_keys:
spec[key] = None

if set(spec.keys(True, True)) != set(self.out_keys):
spec_keys = unravel_key_list(list(spec.keys(True, True)))
if spec_keys != set(self.out_keys):
raise RuntimeError(
f"spec keys and out_keys do not match, got: {set(spec.keys(True, True))} and {set(self.out_keys)} respectively"
f"spec keys and out_keys do not match, got: {spec_keys} and {set(self.out_keys)} respectively"
)

self._spec = spec
Expand Down