Skip to content

Commit

Permalink
[BugFix] Fix vecnorm state-dicts (#2158)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 10, 2024
1 parent 7befddc commit e77f0dd
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 26 deletions.
56 changes: 54 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7777,7 +7777,9 @@ def test_vecnorm_parallel_auto(self, nprc):
for idx in range(nprc):
queues[idx][1].put(msg)

td = make_env.state_dict()["transforms.1._extra_state"]["td"]
td = TensorDict(
make_env.state_dict()["transforms.1._extra_state"]
).unflatten_keys(VecNorm.SEP)

obs_sum = td.get(("some", "obs_sum")).clone()
obs_ssq = td.get(("some", "obs_ssq")).clone()
Expand Down Expand Up @@ -7878,7 +7880,9 @@ def test_parallelenv_vecnorm(self):
parallel_sd = parallel_env.state_dict()
assert "worker0" in parallel_sd
worker_sd = parallel_sd["worker0"]
td = worker_sd["transforms.1._extra_state"]["td"]
td = TensorDict(worker_sd["transforms.1._extra_state"]).unflatten_keys(
VecNorm.SEP
)
queue_out.put("start")
msg = queue_in.get(timeout=TIMEOUT)
assert msg == "first round"
Expand Down Expand Up @@ -7952,6 +7956,54 @@ def test_pickable(self):
for key in sorted(transform.__dict__.keys()):
assert isinstance(transform.__dict__[key], type(transform2.__dict__[key]))

def test_state_dict_vecnorm(self):
    """Check VecNorm state-dict round-trips.

    Covers: saving/loading an uninitialized (empty) state-dict, rejecting
    an empty state-dict once the transform is initialized, clone semantics
    (shared-ness preserved, storage deep-copied), and loading initialized
    statistics into fresh transforms.
    """
    transform0 = Compose(
        VecNorm(in_keys=["a", ("b", "c")], out_keys=["a_avg", ("b", "c_avg")])
    )
    td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
    # An uninitialized VecNorm warns when its state-dict is queried.
    with pytest.warns(UserWarning, match="Querying state_dict on an uninitialized"):
        sd_empty = transform0.state_dict()

    transform1 = transform0.clone()
    # Loading an empty state-dict on an uninitialized transform is a no-op.
    transform1.load_state_dict(sd_empty)
    transform1._step(td, td)
    # Once initialized, loading an empty state-dict must fail loudly instead
    # of silently dropping the accumulated statistics.
    with pytest.raises(KeyError, match="Could not find a tensordict"):
        transform1.load_state_dict(sd_empty)

    transform0._step(td, td)
    sd = transform0.state_dict()

    transform1 = transform0.clone()
    assert transform0[0]._td.is_shared() is transform1[0]._td.is_shared()

    def assert_differs(a, b):
        # clone() must deep-copy the statistics: the storages may not alias.
        # (The comparison is `!=` to match the helper's name and intent.)
        assert a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr()

    transform1[0]._td.apply(assert_differs, transform0[0]._td, filter_empty=True)

    transform1 = Compose(
        VecNorm(in_keys=["a", ("b", "c")], out_keys=["a_avg", ("b", "c_avg")])
    )
    # Loading non-shared stats into an uninitialized transform warns about
    # multiprocess sharing before sharing the tensordict.
    with pytest.warns(UserWarning, match="VecNorm wasn't initialized"):
        transform1.load_state_dict(sd)
    transform1._step(td, td)

    transform1 = Compose(
        VecNorm(in_keys=["a", ("b", "c")], out_keys=["a_avg", ("b", "c_avg")])
    )
    transform1._step(td, td)
    transform1.load_state_dict(sd)

def test_to_obsnorm_multikeys(self):
    """VecNorm and its ObservationNorm conversion must normalize identically."""
    data = TensorDict(
        {"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4]
    )
    vecnorm_compose = Compose(
        VecNorm(in_keys=["a", ("b", "c")], out_keys=["a_avg", ("b", "c_avg")])
    )
    # First pass initializes the running statistics and normalizes.
    out_vecnorm = vecnorm_compose._step(data, data.clone())
    # The frozen ObservationNorm built from those statistics must agree.
    obsnorm = vecnorm_compose[0].to_observation_norm()
    out_obsnorm = obsnorm._step(data, data.clone())
    assert_allclose_td(out_vecnorm, out_obsnorm)


def test_added_transforms_are_in_eval_mode_trivial():
base_env = ContinuousActionVecMockEnv()
Expand Down
64 changes: 40 additions & 24 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import collections
import functools
import importlib.util
import multiprocessing as mp
Expand Down Expand Up @@ -4816,6 +4815,8 @@ class VecNorm(Transform):
Defaults to ``in_keys``.
shared_td (TensorDictBase, optional): A shared tensordict containing the
keys of the transform.
lock (mp.Lock): a lock to prevent race conditions between processes.
Defaults to None (lock created during init).
decay (number, optional): decay rate of the moving average.
default: 0.99
eps (number, optional): lower bound of the running standard
Expand Down Expand Up @@ -5006,7 +5007,7 @@ def _update(self, key, value, N) -> torch.Tensor:
def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
"""Converts VecNorm into an ObservationNorm class that can be used at inference time."""
out = []
for key in self.in_keys:
for key, key_out in zip(self.in_keys, self.out_keys):
_sum = self._td.get(_append_last(key, "_sum"))
_ssq = self._td.get(_append_last(key, "_ssq"))
_count = self._td.get(_append_last(key, "_count"))
Expand All @@ -5017,13 +5018,13 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
loc=mean,
scale=std,
standard_normal=True,
in_keys=self.in_keys,
in_keys=key,
out_keys=key_out,
)
if len(self.in_keys) == 1:
return _out
else:
out += ObservationNorm
return Compose(*out)
out += [_out]
if len(self.in_keys) > 1:
return Compose(*out)
return _out

@staticmethod
def build_td_for_shared_vecnorm(
Expand Down Expand Up @@ -5090,29 +5091,44 @@ def build_td_for_shared_vecnorm(
return td_select.memmap_()
return td_select.share_memory_()

# We use a different separator to ensure that keys can have points within them.
SEP = "-<.>-"

def get_extra_state(self) -> OrderedDict:
    """Return the running statistics to be serialized in the state-dict.

    The internal tensordict is flattened with ``self.SEP`` so that the
    result is a plain, flat dictionary of tensors (compatible with
    ``nn.Module`` extra-state handling).

    Returns:
        A flat dict of tensors, or ``None`` (with a warning) when the
        transform has not been initialized yet, i.e. before any data was
        seen.
    """
    # NOTE: the pre-refactor `return collections.OrderedDict({"lock": ...,
    # "td": ...})` path is gone: locks cannot be serialized and the nested
    # tensordict broke state-dict flattening.
    if self._td is None:
        warnings.warn(
            "Querying state_dict on an uninitialized VecNorm transform will "
            "return a `None` value for the summary statistics. "
            "Loading such a state_dict on an initialized VecNorm will result in "
            "an error."
        )
        return
    return self._td.flatten_keys(self.SEP).to_dict()

def set_extra_state(self, state: OrderedDict) -> None:
lock = state["lock"]
if lock is not None:
"""
since locks can't be serialized, we have use cases for stripping them
for example in ParallelEnv, in which case keep the lock we already have
to avoid an updated tensor dict being sent between processes to erase locks
"""
self.lock = lock
td = state["td"]
if td is not None and not td.is_shared():
raise RuntimeError(
"Only shared tensordicts can be set in VecNorm transforms"
)
self._td = td
if state is not None:
td = TensorDict(state).unflatten_keys(self.SEP)
if self._td is None and not td.is_shared():
warnings.warn(
"VecNorm wasn't initialized and the tensordict is not shared. In single "
"process settings, this is ok, but if you need to share the statistics "
"between workers this should require some attention. "
"Make sure that the content of VecNorm is transmitted to the workers "
"after calling load_state_dict and not before, as other workers "
"may not have access to the loaded TensorDict."
)
td.share_memory_()
if self._td is not None:
self._td.update_(td)
else:
self._td = td
elif self._td is not None:
raise KeyError("Could not find a tensordict in the state_dict.")

def __repr__(self) -> str:
    # Report the normalization hyper-parameters and the in/out key mapping.
    # (A stale pre-refactor f-string fragment duplicating the eps/keys
    # segment was removed: implicit concatenation rendered eps twice.)
    return (
        f"{self.__class__.__name__}(decay={self.decay:4.4f},"
        f"eps={self.eps:4.4f}, in_keys={self.in_keys}, out_keys={self.out_keys})"
    )

def __getstate__(self) -> Dict[str, Any]:
Expand Down

0 comments on commit e77f0dd

Please sign in to comment.