[BugFix,Feature] Allow non-tensor data in envs #1944

Merged · 11 commits · Jun 19, 2024
Changes from 1 commit
amend
vmoens committed Feb 21, 2024
commit f7bb04b3732b18e7107e74796efcd41c747faba6
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -675,6 +675,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
UnboundedDiscreteTensorSpec
LazyStackedTensorSpec
LazyStackedCompositeSpec
NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
@@ -46,6 +46,9 @@ Each env will have the following attributes:
all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
It is locked and should not be modified directly.

If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec`
instance can be used.

Importantly, the environment spec shapes should contain the batch size, e.g.
an environment with :obj:`env.batch_size == torch.Size([4])` should have
an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`.
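
The docs addition above is terse, so here is a minimal sketch of what declaring a non-tensor entry in an env's spec could look like. The composition below is illustrative only and is not taken from this PR; CompositeSpec and UnboundedContinuousTensorSpec are pre-existing torchrl specs.

# Illustrative only: an observation spec carrying a NonTensorSpec entry
# alongside ordinary tensor specs.
from torchrl.data import (
    CompositeSpec,
    NonTensorSpec,
    UnboundedContinuousTensorSpec,
)

observation_spec = CompositeSpec(
    observation=UnboundedContinuousTensorSpec(shape=(3,)),
    # this entry can carry arbitrary Python objects (strings, dicts, ...)
    non_tensor=NonTensorSpec(shape=()),
)
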
20 changes: 20 additions & 0 deletions test/test_env.py
@@ -2566,6 +2566,26 @@ def make_env(seed, device=device):
p_env.close()


class TestNonTensorEnv:
    @pytest.mark.parametrize("bwad", [True, False])
    def test_single(self, bwad):
        env = EnvWithMetadata()
        r = env.rollout(10, break_when_any_done=bwad)
        assert r.get("non_tensor").tolist() == list(range(10))

    @pytest.mark.parametrize("bwad", [True, False])
    def test_serial(self, bwad):
        env = SerialEnv(2, EnvWithMetadata)
        r = env.rollout(10, break_when_any_done=bwad)
        assert r.get("non_tensor").tolist() == [list(range(10))] * 2

    @pytest.mark.parametrize("bwad", [True, False])
    def test_parallel(self, bwad):
        env = ParallelEnv(2, EnvWithMetadata)
        r = env.rollout(10, break_when_any_done=bwad)
        assert r.get("non_tensor").tolist() == [list(range(10))] * 2


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
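
EnvWithMetadata is a test fixture defined elsewhere in the test suite, not in this diff. As a purely hypothetical stand-in for it, the sketch below shows the behaviour the assertions rely on: a counter written into a non-tensor entry at every step, so a 10-step rollout yields non_tensor values 0..9 at the root.

# Hypothetical sketch (not the real fixture): a counter stored as
# non-tensor data at each step.
import torch
from tensordict import NonTensorData, TensorDict
from torchrl.data import (
    CompositeSpec,
    NonTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase


class ToyMetadataEnv(EnvBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(shape=(3,)),
            non_tensor=NonTensorSpec(shape=()),
        )
        self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
        self.count = 0

    def _reset(self, tensordict):
        self.count = 0
        return TensorDict(
            {
                "observation": torch.zeros(3),
                "non_tensor": NonTensorData(self.count, batch_size=[]),
            },
            batch_size=[],
        )

    def _step(self, tensordict):
        self.count += 1
        return TensorDict(
            {
                "observation": torch.zeros(3),
                "non_tensor": NonTensorData(self.count, batch_size=[]),
                "reward": torch.zeros(1),
                "done": torch.zeros(1, dtype=torch.bool),
            },
            batch_size=[],
        )

    def _set_seed(self, seed):
        return seed
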
22 changes: 21 additions & 1 deletion torchrl/data/tensor_specs.py
@@ -3973,6 +3973,8 @@ class NonTensorSpec(TensorSpec):
    domain: str = "continuous"

    def __init__(self, shape=(), device=None):
        if device is not None:
            device = torch.device(device)
        return super().__init__(shape=shape, space=None, device=device, dtype=None)

    def rand(self, shape=None) -> torch.Tensor:
@@ -3982,14 +3984,32 @@ def zero(self, shape=None) -> torch.Tensor:
        return NonTensorData(None, batch_size=self.shape)

    def expand(self, *shape):
-        return NonTensorSpec(shape=shape)
+        if len(shape) == 1 and not isinstance(shape[0], int):
+            shape = shape[0]
+        return NonTensorSpec(device=self.device, shape=shape)

    def is_in(self, val: torch.Tensor) -> bool:
        return True

    def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def to(self, dest):
        if dest is None:
            return self
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device)

    def clone(self) -> "TensorSpec":
        return self.__class__(shape=self.shape, device=self.device)

class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
"""A lazy representation of a stack of composite specs.
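
To summarize the NonTensorSpec semantics introduced above, a short usage sketch follows; the behaviour is inferred from the diff itself, so treat it as illustrative rather than normative.

# Behaviour of the methods added/changed above, per the diff.
import torch
from torchrl.data import NonTensorSpec

spec = NonTensorSpec(shape=(2,), device="cpu")

# rand() and zero() return placeholder NonTensorData, not tensors:
sample = spec.rand()  # NonTensorData(None, batch_size=torch.Size([2]))

# expand() now accepts varargs or a single shape-like and keeps the device:
assert spec.expand(4, 2).shape == torch.Size([4, 2])
assert spec.expand((4, 2)).shape == torch.Size([4, 2])

# to() short-circuits when nothing changes; clone() returns a new spec:
assert spec.to("cpu") is spec
assert spec.clone() is not spec
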
129 changes: 110 additions & 19 deletions torchrl/envs/batched_envs.py
@@ -19,16 +19,17 @@

import torch

- from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
+ from tensordict import LazyStackedTensorDict, NonTensorData, TensorDict, TensorDictBase
from tensordict._tensordict import unravel_key
from tensordict.tensorclass import NonTensorStack
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
_ProcessNoWarn,
logger as torchrl_logger,
VERBOSE,
)
- from torchrl.data.tensor_specs import CompositeSpec
+ from torchrl.data.tensor_specs import CompositeSpec, NonTensorSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _EnvPostInit, EnvBase
from torchrl.envs.env_creator import get_env_metadata
@@ -563,12 +564,30 @@ def _create_td(self) -> None:
# output keys after step
self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}

# Non-tensor keys
non_tensor_keys = []
for key, spec in self.observation_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
for key, spec in self.full_reward_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
for key, spec in self.full_done_spec.items(True, True):
if isinstance(spec, NonTensorSpec):
non_tensor_keys.append(key)
self._non_tensor_keys = non_tensor_keys

if self._single_task:
shared_tensordict_parent = shared_tensordict_parent.select(
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
strict=False,
)
for key, item in list(shared_tensordict_parent.items(True)):
if isinstance(item, NonTensorData):
shared_tensordict_parent.set(
key, NonTensorStack(*item.unbind(0), stack_dim=0)
)
self.shared_tensordict_parent = shared_tensordict_parent
else:
# Multi-task: we share tensordict that *may* have different keys
@@ -810,7 +829,6 @@ def select_and_clone(name, tensor):
nested_keys=True,
filter_empty=True,
)

if out.device != device:
if device is None:
out = out.clear_device_()
@@ -1101,6 +1119,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
"_selected_input_keys": self._selected_input_keys,
"_selected_reset_keys": self._selected_reset_keys,
"_selected_step_keys": self._selected_step_keys,
"_non_tensor_keys": self._non_tensor_keys,
"has_lazy_inputs": self.has_lazy_inputs,
}
)
@@ -1172,17 +1191,39 @@ def step_and_maybe_reset(
# We keep track of which keys are present to let the worker know what
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
data = [
{"next_td_passthrough_keys": next_td_keys}
for _ in range(self.num_workers)
]
self.shared_tensordict_parent.get("next").update_(next_td_passthrough)
else:
next_td_keys = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
for i in range(self.num_workers):
data[i]["non_tensor_data"] = tensordict[i].select(
*self._non_tensor_keys, strict=False
)

for i in range(self.num_workers):
self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys))
self.parent_channels[i].send(("step_and_maybe_reset", data[i]))

for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()

if self._non_tensor_keys:
for i in range(self.num_workers):
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].update_(
non_tensor_td,
keys_to_update=[
*self._non_tensor_keys,
*[("next", key) for key in self._non_tensor_keys],
],
)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
next_td = self._shared_tensordict_parent_next
@@ -1233,21 +1274,38 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We keep track of which keys are present to let the worker know what
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
data = [
{"next_td_passthrough_keys": next_td_keys}
for _ in range(self.num_workers)
]
self.shared_tensordict_parent.get("next").update_(next_td_passthrough)
else:
next_td_keys = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
for i in range(self.num_workers):
data[i]["non_tensor_data"] = tensordict[i].select(
*self._non_tensor_keys, strict=False
)

if self.event is not None:
self.event.record()
self.event.synchronize()
for i in range(self.num_workers):
self.parent_channels[i].send(("step", next_td_keys))
self.parent_channels[i].send(("step", data[i]))

for i in range(self.num_workers):
event = self._events[i]
event.wait()
event.clear()

if self._non_tensor_keys:
for i in range(self.num_workers):
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].get("next").update_(
non_tensor_td, keys_to_update=self._non_tensor_keys
)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
next_td = self.shared_tensordict_parent.get("next")
@@ -1336,6 +1394,13 @@ def tentative_update(val, other):
event.wait()
event.clear()

if self._non_tensor_keys:
for i in workers:
msg, non_tensor_td = self.parent_channels[i].recv()
self.shared_tensordicts[i].update_(
non_tensor_td, keys_to_update=self._non_tensor_keys
)

selected_output_keys = self._selected_reset_keys_filt
device = self.device

@@ -1472,6 +1537,7 @@ def _run_worker_pipe_shared_mem(
_selected_input_keys=None,
_selected_reset_keys=None,
_selected_step_keys=None,
_non_tensor_keys=None,
has_lazy_inputs: bool = False,
verbose: bool = False,
) -> None:
@@ -1567,26 +1633,42 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
event.record()
event.synchronize()
mp_event.set()

if _non_tensor_keys:
child_pipe.send(
("non_tensor", cur_td.select(*_non_tensor_keys, strict=False))
)

del cur_td

elif cmd == "step":
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
# No need to copy here since we don't write in-place
- if data:
-     next_td_passthrough_keys = data
-     input = root_shared_tensordict.set(
-         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
-     )
- else:
-     input = root_shared_tensordict
+ input = root_shared_tensordict
+ next_td_passthrough_keys = data.get("next_td_passthrough_keys", None)
+ if next_td_passthrough_keys is not None:
+     input = input.set(
+         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
+     )
+ non_tensor_data = data.get("non_tensor_data", None)
+ if non_tensor_data is not None:
+     input.update_(non_tensor_data)

next_td = env._step(input)
next_shared_tensordict.update_(next_td)
if event is not None:
event.record()
event.synchronize()
mp_event.set()

if _non_tensor_keys:
child_pipe.send(
("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
)

del next_td

elif cmd == "step_and_maybe_reset":
@@ -1601,13 +1683,16 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
# by output) of the step!
# Caveat: for RNN we may need some keys of the "next" TD so we pass the list
# through data
- if data:
-     next_td_passthrough_keys = data
-     input = root_shared_tensordict.set(
-         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
-     )
- else:
-     input = root_shared_tensordict
+ input = root_shared_tensordict
+ next_td_passthrough_keys = data.get("next_td_passthrough_keys", None)
+ if next_td_passthrough_keys is not None:
+     input = input.set(
+         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
+     )
+ non_tensor_data = data.get("non_tensor_data", None)
+ if non_tensor_data is not None:
+     input.update_(non_tensor_data)
td, root_next_td = env.step_and_maybe_reset(input)
next_shared_tensordict.update_(td.pop("next"))
root_shared_tensordict.update_(root_next_td)
@@ -1616,6 +1701,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
event.record()
event.synchronize()
mp_event.set()

if _non_tensor_keys:
ntd = root_next_td.select(*_non_tensor_keys)
ntd.set("next", next_shared_tensordict.select(*_non_tensor_keys))
child_pipe.send(("non_tensor", ntd))

del td, root_next_td

elif cmd == "close":
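
Stripped of the tensordict machinery, the parent/worker exchange added above follows a simple pattern: each command now carries a payload dict, and non-tensor values travel over the pipe in both directions, since only tensors can live in shared memory. A standalone sketch of just that message flow (plain Python, not torchrl code):

# Minimal model of the protocol: tensors would move through shared
# memory; non-tensor data is piped explicitly on every step.
import multiprocessing as mp


def worker(pipe):
    metadata = 0
    while True:
        cmd, data = pipe.recv()
        if cmd == "step":
            # The parent ships non-tensor inputs alongside the command...
            metadata = data.get("non_tensor_data", metadata)
            metadata += 1
            # ...and the worker ships the refreshed values back.
            pipe.send(("non_tensor", metadata))
        elif cmd == "close":
            pipe.close()
            return


if __name__ == "__main__":
    parent, child = mp.Pipe()
    proc = mp.Process(target=worker, args=(child,))
    proc.start()
    parent.send(("step", {"non_tensor_data": 0}))
    msg, payload = parent.recv()  # ("non_tensor", 1)
    parent.send(("close", None))
    proc.join()
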
7 changes: 4 additions & 3 deletions torchrl/envs/utils.py
@@ -19,6 +19,7 @@
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
TensorDictBase,
unravel_key,
)
@@ -292,7 +293,9 @@ def _set(source, dest, key, total_key, excluded):
if unravel_key(total_key) not in excluded:
try:
val = source.get(key)
- if is_tensor_collection(val) and not isinstance(val, NonTensorData):
+ if is_tensor_collection(val) and not isinstance(
+     val, (NonTensorData, NonTensorStack)
+ ):
new_val = dest.get(key, None)
if new_val is None:
new_val = val.empty()
@@ -493,8 +496,6 @@ def check_env_specs(
- List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
"""
)
- print(torch.zeros_like(fake_tensordict_select))
- print(torch.zeros_like(real_tensordict_select))
if (
torch.zeros_like(fake_tensordict_select)
!= torch.zeros_like(real_tensordict_select)
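
A final note on the utils.py change: NonTensorData and its stacked counterpart NonTensorStack both register as tensor collections, so _set must name both in the isinstance check to copy them wholesale instead of recursing into their fields. A small sketch of the distinction, assuming tensordict's stacking behaviour (stacking NonTensorData with differing contents yields a NonTensorStack):

# Both classes answer True to is_tensor_collection, hence the pair in
# the isinstance check above.
import torch
from tensordict import NonTensorData, is_tensor_collection

single = NonTensorData("a", batch_size=[])
stacked = torch.stack(
    [NonTensorData("a", batch_size=[]), NonTensorData("b", batch_size=[])]
)  # differing contents -> a NonTensorStack

assert is_tensor_collection(single)
assert is_tensor_collection(stacked)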