Skip to content

Commit

Permalink
[BugFix] Fix non-tensor passage in _StepMDP (pytorch#2262)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 2, 2024
1 parent d3f62d6 commit 79fa8bf
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
51 changes: 50 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
dense_stack_tds,
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import _unravel_key_to_tuple
Expand All @@ -68,6 +69,7 @@
from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
NonTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
Expand All @@ -84,7 +86,11 @@
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
Transform,
)
from torchrl.envs.utils import (
_StepMDP,
_terminated_or_truncated,
Expand Down Expand Up @@ -3188,6 +3194,49 @@ def test_parallel(self, bwad, use_buffers):
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2

class AddString(Transform):
    """Test transform that mirrors a string-encoded counter into the data.

    The counter lives on the transform as ``self._str``. It is set to ``"0"``
    on construction and on every ``_reset``, and is incremented by one each
    time ``_call`` runs. Its current value is written under the ``"string"``
    key, which is declared in the observation spec as a ``NonTensorSpec``.
    """

    def __init__(self):
        super().__init__()
        # The counter is stored as a string rather than as an int.
        self._str = "0"

    def _call(self, td):
        """Increment the counter, write it to ``td["string"]`` and return ``td``."""
        bumped = int(self._str) + 1
        td["string"] = str(bumped)
        # Read the value back from ``td`` so the cached state is whatever the
        # container actually holds.
        self._str = td["string"]
        return td

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Restart the counter at ``"0"`` and write it into ``tensordict_reset``."""
        initial = "0"
        self._str = initial
        tensordict_reset["string"] = initial
        return tensordict_reset

    def transform_observation_spec(self, observation_spec):
        """Add a ``NonTensorSpec`` entry for ``"string"`` to the observation spec."""
        string_spec = NonTensorSpec(())
        observation_spec["string"] = string_spec
        return observation_spec

@pytest.mark.parametrize("batched", ["serial", "parallel"])
def test_partial_rest(self, batched):
    """Check non-tensor entries when only part of a batched env is reset.

    Two ``CountingEnv`` instances (built with arguments 5 and 6), each wrapped
    with ``AddString``, are stepped together through ``step_and_maybe_reset``
    until any of them reports ``done``. The test then checks that:

    * the loop stopped at iteration index 5 with only the first env done;
    * the post-reset data holds ``"0"`` for the reset env and ``"6"`` for the
      env that kept running;
    * the ``"next"`` data still holds ``"6"`` for both envs, i.e. the reset
      did not overwrite the step output.

    Args:
        batched: ``"serial"`` to use ``SerialEnv``, ``"parallel"`` to use
            ``ParallelEnv``.
    """
    env0 = lambda: CountingEnv(5).append_transform(self.AddString())
    env1 = lambda: CountingEnv(6).append_transform(self.AddString())
    if batched == "parallel":
        env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
    else:
        env = SerialEnv(2, [env0, env1])
    try:
        s = env.reset()
        i = 0
        for i in range(10):  # noqa: B007
            s, s_ = env.step_and_maybe_reset(
                s.set("action", torch.ones(2, 1, dtype=torch.int))
            )
            if s.get(("next", "done")).any():
                break
            # Continue from the (possibly partially reset) next-step data.
            s = s_
        assert i == 5
        assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
        assert s_["string"] == ["0", "6"]
        assert s["next", "string"] == ["6", "6"]
    finally:
        # Always release the sub-environments (worker processes in the
        # parallel case), including when an assertion above fails. The
        # ``is_closed`` guard avoids closing an env that never started.
        if not env.is_closed:
            env.close()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tds = []
for i, _env in enumerate(self._envs):
if not needs_resetting[i]:
if not self._use_buffers and tensordict is not None:
if out_tds is not None and tensordict is not None:
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
continue
if tensordict is not None:
Expand Down
5 changes: 3 additions & 2 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
set_interaction_mode as set_exploration_mode,
set_interaction_type as set_exploration_type,
)
from tensordict.utils import NestedKey
from tensordict.utils import is_non_tensor, NestedKey
from torch import nn as nn
from torch.utils._pytree import tree_map
from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger
Expand Down Expand Up @@ -254,6 +254,8 @@ def _grab_and_place(
if not _allow_absent_keys:
raise KeyError(f"key {key} not found.")
else:
if is_non_tensor(val):
val = val.clone()
data_out._set_str(
key, val, validated=True, inplace=False, non_blocking=False
)
Expand Down Expand Up @@ -1403,7 +1405,6 @@ def _update_during_reset(
reset = reset.any(-1)
reset = reset.reshape(node.shape)
# node.update(node.where(~reset, other=node_reset, pad=0))

node.where(~reset, other=node_reset, out=node, pad=0)
# node = node.clone()
# idx = reset.nonzero(as_tuple=True)[0]
Expand Down

0 comments on commit 79fa8bf

Please sign in to comment.