Skip to content

Commit

Permalink
[Feature] CompositeSpec.lock (pytorch#1143)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 9, 2023
1 parent 6d030c9 commit 71c1657
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,40 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype):
assert ts["nested_cp"]["act"] is not None


@pytest.mark.parametrize("recurse", [True, False])
def test_lock(recurse):
shape = [3, 4, 5]
spec = CompositeSpec(
a=CompositeSpec(
b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
assert not spec.locked
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
spec["a"].set("b", spec["a", "b"].clone())
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())
spec.unlock_(recurse=recurse)
spec["a"] = spec["a"].clone()
spec["a", "b"] = spec["a", "b"].clone()
spec["a"].set("b", spec["a", "b"].clone())


def test_keys_to_empty_composite_spec():
keys = [("key1", "out"), ("key1", "in"), "key2", ("key1", "subkey1", "subkey2")]
composite = _keys_to_empty_composite_spec(keys)
Expand Down
67 changes: 67 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2422,6 +2422,7 @@ class CompositeSpec(TensorSpec):
@classmethod
def __new__(cls, *args, **kwargs):
cls._device = torch.device("cpu")
cls._locked = False
return super().__new__(cls)

@property
Expand All @@ -2447,6 +2448,8 @@ def ndimension(self):
return len(self.shape)

def set(self, name, spec):
if self.locked:
raise RuntimeError("Cannot modify a locked CompositeSpec.")
if spec is not None:
shape = spec.shape
if shape[: self.ndim] != self.shape:
Expand Down Expand Up @@ -2945,6 +2948,70 @@ def unsqueeze(self, dim: int):
device=device,
)

def lock_(self, recurse=False):
"""Locks the CompositeSpec and prevents modification of its content.
This is only a first-level lock, unless specified otherwise through the
``recurse`` arg.
Leaf specs can always be modified in place, but they cannot be replaced
in their CompositeSpec parent.
Examples:
>>> shape = [3, 4, 5]
>>> spec = CompositeSpec(
... a=CompositeSpec(
... b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
... ),
... shape=shape[:1],
... )
>>> spec["a"] = spec["a"].clone()
>>> recurse = False
>>> spec.lock_(recurse=recurse)
>>> try:
... spec["a"] = spec["a"].clone()
... except RuntimeError:
... print("failed!")
failed!
>>> try:
... spec["a", "b"] = spec["a", "b"].clone()
... print("succeeded!")
... except RuntimeError:
... print("failed!")
succeeded!
>>> recurse = True
>>> spec.lock_(recurse=recurse)
>>> try:
... spec["a", "b"] = spec["a", "b"].clone()
... print("succeeded!")
... except RuntimeError:
... print("failed!")
failed!
"""
self._locked = True
if recurse:
for value in self.values():
if isinstance(value, CompositeSpec):
value.lock_(recurse)

def unlock_(self, recurse=False):
"""Unlocks the CompositeSpec and allows modification of its content.
This is only a first-level lock modification, unless specified
otherwise through the ``recurse`` arg.
"""
self._locked = False
if recurse:
for value in self.values():
if isinstance(value, CompositeSpec):
value.unlock_(recurse)

@property
def locked(self):
return self._locked


class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
"""A lazy representation of a stack of composite specs.
Expand Down

0 comments on commit 71c1657

Please sign in to comment.