diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 0bf81562530..0dca499f4d9 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -902,6 +902,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. UnboundedDiscreteTensorSpec LazyStackedTensorSpec LazyStackedCompositeSpec + NonTensorSpec Reinforcement Learning From Human Feedback (RLHF) ------------------------------------------------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 62959c0b3c4..4c4a1e2b8e5 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -46,6 +46,9 @@ Each env will have the following attributes: all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). It is locked and should not be modified directly. +If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec` +instance can be used. + Importantly, the environment spec shapes should contain the batch size, e.g. an environment with :obj:`env.batch_size == torch.Size([4])` should have an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`. 
diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5f660b554a4..ea4327bb460 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -16,6 +16,7 @@ CompositeSpec, DiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, + NonTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, @@ -1825,6 +1826,39 @@ def _set_seed(self, seed: Optional[int]): torch.manual_seed(seed) +class EnvWithMetadata(EnvBase): + def __init__(self): + super().__init__() + self.observation_spec = CompositeSpec( + tensor=UnboundedContinuousTensorSpec(3), + non_tensor=NonTensorSpec(shape=()), + ) + self.state_spec = CompositeSpec( + non_tensor=NonTensorSpec(shape=()), + ) + self.reward_spec = UnboundedContinuousTensorSpec(1) + self.action_spec = UnboundedContinuousTensorSpec(1) + + def _reset(self, tensordict): + data = self.observation_spec.zero() + data.set_non_tensor("non_tensor", 0) + data.update(self.full_done_spec.zero()) + return data + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + data = self.observation_spec.zero() + data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1) + data.update(self.full_done_spec.zero()) + data.update(self.full_reward_spec.zero()) + return data + + def _set_seed(self, seed: Optional[int]): + return seed + + class AutoResettingCountingEnv(CountingEnv): def _step(self, tensordict): tensordict = super()._step(tensordict) diff --git a/test/test_env.py b/test/test_env.py index d6ebb16084c..bfda10f0e93 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -43,6 +43,7 @@ DiscreteActionVecMockEnv, DummyModelBasedEnvBase, EnvWithDynamicSpec, + EnvWithMetadata, HeterogeneousCountingEnv, HeterogeneousCountingEnvPolicy, MockBatchedLockedEnv, @@ -2395,6 +2396,7 @@ def test_parallel( @pytest.mark.parametrize( "envclass", [ + EnvWithMetadata, ContinuousActionConvMockEnv, ContinuousActionConvMockEnvNumpy, ContinuousActionVecMockEnv, @@ -2419,6 +2421,7 @@ def test_mocking_envs(envclass): 
env.set_seed(100) reset = env.reset() _ = env.rand_step(reset) + r = env.rollout(3) check_env_specs(env, seed=100, return_contiguous=False) @@ -3162,6 +3165,30 @@ def test_batched_dynamic(self, break_when_any_done): assert_allclose_td(rollout_no_buffers_serial, rollout_no_buffers_parallel) +class TestNonTensorEnv: + @pytest.mark.parametrize("bwad", [True, False]) + def test_single(self, bwad): + env = EnvWithMetadata() + r = env.rollout(10, break_when_any_done=bwad) + assert r.get("non_tensor").tolist() == list(range(10)) + + @pytest.mark.parametrize("bwad", [True, False]) + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_serial(self, bwad, use_buffers): + N = 50 + env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers) + r = env.rollout(N, break_when_any_done=bwad) + assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + + @pytest.mark.parametrize("bwad", [True, False]) + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_parallel(self, bwad, use_buffers): + N = 50 + env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers) + r = env.rollout(N, break_when_any_done=bwad) + assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 42fc851d2a6..6681bab4cea 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -643,7 +643,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten indexed tensor """ - raise NotImplementedError + ... @abc.abstractmethod def expand(self, *shape): @@ -656,7 +656,7 @@ def expand(self, *shape): from it if the current dimension is a singleton. """ - raise NotImplementedError + ... def squeeze(self, dim: int | None = None): """Returns a new Spec with all the dimensions of size ``1`` removed. 
@@ -740,7 +740,7 @@ def is_in(self, val: torch.Tensor) -> bool: boolean indicating if values belongs to the TensorSpec box """ - raise NotImplementedError + ... def contains(self, item): """Returns whether a sample is contained within the space defined by the TensorSpec. @@ -2120,7 +2120,9 @@ def is_in(self, val: torch.Tensor) -> bool: return ( isinstance(val, NonTensorData) and val.shape == shape - and val.device == self.device + # We relax constraints on device as they're hard to enforce for non-tensor + # tensordicts and pointless + # and val.device == self.device and val.dtype == self.dtype ) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b15d3fc4bdf..0f9f8c9a6d7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -35,7 +35,7 @@ logger as torchrl_logger, VERBOSE, ) -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import CompositeSpec, NonTensorSpec from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData from torchrl.envs.env_creator import get_env_metadata @@ -660,6 +660,20 @@ def _create_td(self) -> None: "batched environment base tensordict has the wrong shape" ) + # Non-tensor keys + non_tensor_keys = [] + for spec in ( + self.full_action_spec, + self.full_state_spec, + self.full_observation_spec, + self.full_reward_spec, + self.full_done_spec, + ): + for key, _spec in spec.items(True, True): + if isinstance(_spec, NonTensorSpec): + non_tensor_keys.append(key) + self._non_tensor_keys = non_tensor_keys + if self._single_task: self._env_input_keys = sorted( list(self.input_spec["full_action_spec"].keys(True, True)) @@ -700,6 +714,15 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) + env_obs_keys = [ + key for key in env_obs_keys if key not in self._non_tensor_keys + ] + env_input_keys = [ + key for 
key in env_input_keys if key not in self._non_tensor_keys + ] + env_output_keys = [ + key for key in env_output_keys if key not in self._non_tensor_keys + ] self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) @@ -727,6 +750,7 @@ def _create_td(self) -> None: self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys} if not self.share_individual_td: + shared_tensordict_parent = shared_tensordict_parent.filter_non_tensor_data() shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), @@ -740,7 +764,7 @@ def _create_td(self) -> None: *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, - ) + ).filter_non_tensor_data() for tensordict in shared_tensordict_parent ] shared_tensordict_parent = LazyStackedTensorDict.lazy_stack( @@ -954,7 +978,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: (self.num_workers,), device=self.device, dtype=torch.bool ) - if not self._use_buffers: + out_tds = None + if not self._use_buffers or self._non_tensor_keys: out_tds = [None] * self.num_workers tds = [] @@ -998,8 +1023,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: "share_individual_td argument to True." 
) raise - else: + if out_tds is not None: out_tds[i] = _td + if not self._use_buffers: result = LazyStackedTensorDict.maybe_dense_stack(out_tds) return result @@ -1017,6 +1043,10 @@ def select_and_clone(name, tensor): nested_keys=True, filter_empty=True, ) + if out_tds is not None: + out.update( + LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys + ) if out.device != device: if device is None: @@ -1045,6 +1075,9 @@ def _step( data_in.append(tensordict_in[i]) self._sync_m2w() + out_tds = None + if not self._use_buffers or self._non_tensor_keys: + out_tds = [] if self._use_buffers: next_td = self.shared_tensordict_parent.get("next") @@ -1055,12 +1088,13 @@ def _step( keys_to_update=list(self._env_output_keys), non_blocking=self.non_blocking, ) + if out_tds is not None: + out_tds.append(out_td) else: - tds = [] for i, _data_in in enumerate(data_in): out_td = self._envs[i]._step(_data_in) - tds.append(out_td) - return LazyStackedTensorDict.maybe_dense_stack(tds) + out_tds.append(out_td) + return LazyStackedTensorDict.maybe_dense_stack(out_tds) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -1071,6 +1105,10 @@ def select_and_clone(name, tensor): return tensor.clone() out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) + if out_tds is not None: + out.update( + LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys + ) if out.device != device: if device is None: @@ -1345,6 +1383,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "_selected_input_keys": self._selected_input_keys, "_selected_reset_keys": self._selected_reset_keys, "_selected_step_keys": self._selected_step_keys, + "_non_tensor_keys": self._non_tensor_keys, } ) process = proc_fun(target=func, kwargs=kwargs[idx]) @@ -1442,20 +1481,38 @@ def step_and_maybe_reset( # We keep track of which keys are present to let the worker know what # should be passd to the env (we don't 
want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) + data = [ + {"next_td_passthrough_keys": next_td_keys} + for _ in range(self.num_workers) + ] self.shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: - next_td_keys = None + # next_td_keys = None + data = [{} for _ in range(self.num_workers)] + + if self._non_tensor_keys: + for i in range(self.num_workers): + data[i]["non_tensor_data"] = tensordict[i].select( + *self._non_tensor_keys, strict=False + ) + self._sync_m2w() for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys)) + self.parent_channels[i].send(("step_and_maybe_reset", data[i])) for i in range(self.num_workers): event = self._events[i] event.wait(self._timeout) event.clear() + if self._non_tensor_keys: + non_tensor_tds = [] + for i in range(self.num_workers): + msg, non_tensor_td = self.parent_channels[i].recv() + non_tensor_tds.append(non_tensor_td) + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps next_td = self._shared_tensordict_parent_next @@ -1484,6 +1541,13 @@ def step_and_maybe_reset( next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() tensordict.set("next", next_td) + if self._non_tensor_keys: + non_tensor_tds = LazyStackedTensorDict(*non_tensor_tds) + tensordict.update( + non_tensor_tds, + keys_to_update=[("next", key) for key in self._non_tensor_keys], + ) + tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys) return tensordict, tensordict_ def _step_no_buffers( @@ -1524,11 +1588,21 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We keep track of which keys are present to let the worker know what # should be passd to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) + data = [ + 
{"next_td_passthrough_keys": next_td_keys} + for _ in range(self.num_workers) + ] self.shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: - next_td_keys = None + data = [{} for _ in range(self.num_workers)] + + if self._non_tensor_keys: + for i in range(self.num_workers): + data[i]["non_tensor_data"] = tensordict[i].select( + *self._non_tensor_keys, strict=False + ) self._sync_m2w() @@ -1536,13 +1610,19 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self.event.record() self.event.synchronize() for i in range(self.num_workers): - self.parent_channels[i].send(("step", next_td_keys)) + self.parent_channels[i].send(("step", data[i])) for i in range(self.num_workers): event = self._events[i] event.wait(self._timeout) event.clear() + if self._non_tensor_keys: + non_tensor_tds = [] + for i in range(self.num_workers): + msg, non_tensor_td = self.parent_channels[i].recv() + non_tensor_tds.append(non_tensor_td) + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") @@ -1566,6 +1646,11 @@ def select_and_clone(name, tensor): filter_empty=True, device=device, ) + if self._non_tensor_keys: + out.update( + LazyStackedTensorDict(*non_tensor_tds), + keys_to_update=self._non_tensor_keys, + ) self._sync_w2m() return out @@ -1683,6 +1768,12 @@ def tentative_update(val, other): event.wait(self._timeout) event.clear() + workers_nontensor = [] + if self._non_tensor_keys: + for i, _ in outs: + msg, non_tensor_td = self.parent_channels[i].recv() + workers_nontensor.append((i, non_tensor_td)) + selected_output_keys = self._selected_reset_keys_filt device = self.device @@ -1704,6 +1795,11 @@ def select_and_clone(name, tensor): filter_empty=True, device=device, ) + if self._non_tensor_keys: + workers, nontensor = zip(*workers_nontensor) + out[torch.tensor(workers)] = 
LazyStackedTensorDict(*nontensor).select( + *self._non_tensor_keys + ) self._sync_w2m() return out @@ -1825,6 +1921,7 @@ def _run_worker_pipe_shared_mem( _selected_input_keys=None, _selected_reset_keys=None, _selected_step_keys=None, + _non_tensor_keys=None, non_blocking: bool = False, has_lazy_inputs: bool = False, verbose: bool = False, @@ -1928,6 +2025,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda): event.record() event.synchronize() mp_event.set() + + if _non_tensor_keys: + child_pipe.send( + ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False)) + ) + del cur_td elif cmd == "step": @@ -1935,19 +2038,29 @@ def look_for_cuda(tensor, has_cuda=has_cuda): raise RuntimeError("called 'init' before step") i += 1 # No need to copy here since we don't write in-place + input = root_shared_tensordict if data: - next_td_passthrough_keys = data - input = root_shared_tensordict.set( - "next", next_shared_tensordict.select(*next_td_passthrough_keys) - ) - else: - input = root_shared_tensordict + next_td_passthrough_keys = data.get("next_td_passthrough_keys") + if next_td_passthrough_keys is not None: + input = input.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + non_tensor_data = data.get("non_tensor_data") + if non_tensor_data is not None: + input.update(non_tensor_data) + next_td = env._step(input) next_shared_tensordict.update_(next_td, non_blocking=non_blocking) if event is not None: event.record() event.synchronize() mp_event.set() + + if _non_tensor_keys: + child_pipe.send( + ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) + ) + del next_td elif cmd == "step_and_maybe_reset": @@ -1962,21 +2075,31 @@ def look_for_cuda(tensor, has_cuda=has_cuda): # by output) of the step! 
# Caveat: for RNN we may need some keys of the "next" TD so we pass the list # through data + input = root_shared_tensordict if data: - next_td_passthrough_keys = data - input = root_shared_tensordict.set( - "next", next_shared_tensordict.select(*next_td_passthrough_keys) - ) - else: - input = root_shared_tensordict + next_td_passthrough_keys = data.get("next_td_passthrough_keys", None) + if next_td_passthrough_keys is not None: + input = input.set( + "next", next_shared_tensordict.select(*next_td_passthrough_keys) + ) + non_tensor_data = data.get("non_tensor_data", None) + if non_tensor_data is not None: + input.update(non_tensor_data) td, root_next_td = env.step_and_maybe_reset(input) - next_shared_tensordict.update_(td.pop("next"), non_blocking=non_blocking) + td_next = td.pop("next") + next_shared_tensordict.update_(td_next, non_blocking=non_blocking) root_shared_tensordict.update_(root_next_td, non_blocking=non_blocking) if event is not None: event.record() event.synchronize() mp_event.set() + + if _non_tensor_keys: + ntd = root_next_td.select(*_non_tensor_keys) + ntd.set("next", td_next.select(*_non_tensor_keys)) + child_pipe.send(("non_tensor", ntd)) + del td, root_next_td elif cmd == "close": @@ -2059,12 +2182,12 @@ def _run_worker_pipe_direct( env = env_fun del env_fun for spec in env.output_spec.values(True, True): - if spec.device.type == "cuda": + if spec.device is not None and spec.device.type == "cuda": has_cuda = True break else: for spec in env.input_spec.values(True, True): - if spec.device.type == "cuda": + if spec.device is not None and spec.device.type == "cuda": has_cuda = True break else: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e9da917ba03..337b7ef8f9e 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,12 +16,15 @@ from enum import Enum from typing import Any, Dict, List, Union -import tensordict +import tensordict.base + import torch from tensordict import ( is_tensor_collection, 
LazyStackedTensorDict, + NonTensorData, + NonTensorStack, TensorDict, TensorDictBase, unravel_key, @@ -268,7 +271,7 @@ def _exclude( cls._exclude(nested_key_dict, td, td_out) return out has_set = False - for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor): + for key, value in data_in.items(is_leaf=_is_leaf_nontensor): subdict = nested_key_dict.get(key, NO_DEFAULT) if subdict is NO_DEFAULT: value = value.copy() if is_tensor_collection(value) else value @@ -562,7 +565,9 @@ def _set(source, dest, key, total_key, excluded): if unravel_key(total_key) not in excluded: try: val = source.get(key) - if is_tensor_collection(val): + if is_tensor_collection(val) and not isinstance( + val, (NonTensorData, NonTensorStack) + ): # if val is a tensordict we need to copy the structure new_val = dest.get(key, None) if new_val is None: @@ -748,11 +753,19 @@ def check_env_specs( [fake_tensordict.clone() for _ in range(3)], -1 ) # eliminate empty containers - fake_tensordict_select = fake_tensordict.select(*fake_tensordict.keys(True, True)) - real_tensordict_select = real_tensordict.select(*real_tensordict.keys(True, True)) + fake_tensordict_select = fake_tensordict.select( + *fake_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + ) + real_tensordict_select = real_tensordict.select( + *real_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + ) # check keys - fake_tensordict_keys = set(fake_tensordict.keys(True, True)) - real_tensordict_keys = set(real_tensordict.keys(True, True)) + fake_tensordict_keys = set( + fake_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + ) + real_tensordict_keys = set( + real_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + ) if fake_tensordict_keys != real_tensordict_keys: raise AssertionError( f"""The keys of the specs and data do not match: