# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import expand_right, NestedKey

from torchrl.data.tensor_specs import (
    BinaryDiscreteTensorSpec,
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    MultiOneHotDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.common import ModelBasedEnvBase

spec_dict = {
    "bounded": BoundedTensorSpec,
    "one_hot": OneHotDiscreteTensorSpec,
    "categorical": DiscreteTensorSpec,
    "unbounded": UnboundedContinuousTensorSpec,
    "binary": BinaryDiscreteTensorSpec,
    "mult_one_hot": MultiOneHotDiscreteTensorSpec,
    "composite": CompositeSpec,
}

default_spec_kwargs = {
    OneHotDiscreteTensorSpec: {"n": 7},
    DiscreteTensorSpec: {"n": 7},
    BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
    UnboundedContinuousTensorSpec: {
        "shape": [
            7,
        ]
    },
    BinaryDiscreteTensorSpec: {"n": 7},
    MultiOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]},
    CompositeSpec: {},
}


def make_spec(spec_str):
    target_class = spec_dict[spec_str]
    return target_class(**default_spec_kwargs[target_class])


class _MockEnv(EnvBase):
    @classmethod
    def __new__(
        cls,
        *args,
        **kwargs,
    ):
        for key, item in list(cls._output_spec["full_observation_spec"].items()):
            cls._output_spec["full_observation_spec"][key] = item.to(
                torch.get_default_dtype()
            )
        reward_spec = cls._output_spec["full_reward_spec"]
        if isinstance(reward_spec, CompositeSpec):
            reward_spec = CompositeSpec(
                {
                    key: item.to(torch.get_default_dtype())
                    for key, item in reward_spec.items(True, True)
                },
                shape=reward_spec.shape,
                device=reward_spec.device,
            )
        else:
            reward_spec = reward_spec.to(torch.get_default_dtype())
        cls._output_spec["full_reward_spec"] = reward_spec
        if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
            cls._output_spec["full_reward_spec"] = CompositeSpec(
                reward=cls._output_spec["full_reward_spec"],
                shape=cls._output_spec["full_reward_spec"].shape[:-1],
            )
        if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
            cls._output_spec["full_done_spec"] = CompositeSpec(
                done=cls._output_spec["full_done_spec"].clone(),
                terminated=cls._output_spec["full_done_spec"].clone(),
                shape=cls._output_spec["full_done_spec"].shape[:-1],
            )
        if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
            cls._input_spec["full_action_spec"] = CompositeSpec(
                action=cls._input_spec["full_action_spec"],
                shape=cls._input_spec["full_action_spec"].shape[:-1],
            )
        dtype = kwargs.pop("dtype", torch.get_default_dtype())
        for spec in (cls._output_spec, cls._input_spec):
            if dtype != torch.get_default_dtype():
                for key, val in list(spec.items(True, True)):
                    if val.dtype == torch.get_default_dtype():
                        val = val.to(dtype)
                        spec[key] = val
        return super().__new__(cls, *args, **kwargs)

    def __init__(
        self,
        *args,
        seed: int = 100,
        **kwargs,
    ):
        super().__init__(
            device=kwargs.pop("device", "cpu"),
            dtype=torch.get_default_dtype(),
            allow_done_after_reset=kwargs.pop("allow_done_after_reset", False),
        )
        self.set_seed(seed)
        self.is_closed = False

    @property
    def maxstep(self):
        return 100

    def _set_seed(self, seed: Optional[int]):
        self.seed = seed
        self.counter = seed % 17  # make counter a small number

    def custom_fun(self):
        return 0
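
    # Descriptive note: the custom_* members below are plain callables, attributes
    # and properties; they presumably exist so tests can check that attribute
    # access is forwarded correctly through env wrappers and transforms.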
    custom_attr = 1

    @property
    def custom_prop(self):
        return 2

    @property
    def custom_td(self):
        return TensorDict({"a": torch.zeros(3)}, [])


class MockSerialEnv(EnvBase):
    """A simple counting env that is reset after a predefined max number of steps."""

    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if action_spec is None:
            action_spec = UnboundedContinuousTensorSpec(
                (
                    *batch_size,
                    1,
                )
            )
        if observation_spec is None:
            observation_spec = CompositeSpec(
                observation=UnboundedContinuousTensorSpec(
                    (
                        *batch_size,
                        1,
                    )
                ),
                shape=batch_size,
            )
        if reward_spec is None:
            reward_spec = UnboundedContinuousTensorSpec(
                (
                    *batch_size,
                    1,
                )
            )
        if done_spec is None:
            done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
        if state_spec is None:
            state_spec = CompositeSpec(shape=batch_size)
        input_spec = CompositeSpec(
            full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size
        )
        cls._output_spec = CompositeSpec(shape=batch_size)
        cls._output_spec["full_reward_spec"] = reward_spec
        cls._output_spec["full_done_spec"] = done_spec
        cls._output_spec["full_observation_spec"] = observation_spec
        cls._input_spec = input_spec
        if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
            cls._output_spec["full_reward_spec"] = CompositeSpec(
                reward=cls._output_spec["full_reward_spec"], shape=batch_size
            )
        if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
            cls._output_spec["full_done_spec"] = CompositeSpec(
                done=cls._output_spec["full_done_spec"], shape=batch_size
            )
        if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
            cls._input_spec["full_action_spec"] = CompositeSpec(
                action=cls._input_spec["full_action_spec"], shape=batch_size
            )
        return super().__new__(*args, **kwargs)

    def __init__(self, device="cpu"):
        super(MockSerialEnv, self).__init__(device=device)
        self.is_closed = False

    def _set_seed(self, seed: Optional[int]):
        assert seed >= 1
        self.seed = seed
        self.counter = seed % 17  # make counter a small number
        self.max_val = max(self.counter + 100, self.counter * 2)

    def _step(self, tensordict):
        self.counter += 1
        n = torch.tensor(
            [self.counter], device=self.device, dtype=torch.get_default_dtype()
        )
        done = self.counter >= self.max_val
        done = torch.tensor([done], dtype=torch.bool, device=self.device)
        return TensorDict(
            {
                "reward": n,
                "done": done,
                "terminated": done.clone(),
                "observation": n.clone(),
            },
            batch_size=[],
        )

    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
        self.max_val = max(self.counter + 100, self.counter * 2)
        n = torch.tensor(
            [self.counter], device=self.device, dtype=torch.get_default_dtype()
        )
        done = self.counter >= self.max_val
        done = torch.tensor([done], dtype=torch.bool, device=self.device)
        return TensorDict(
            {"done": done, "terminated": done.clone(), "observation": n}, []
        )

    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
        return self.step(tensordict)


class MockBatchedLockedEnv(EnvBase):
    """Mocks an env whose batch_size defines the size of the output tensordict."""

    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if action_spec is None:
            action_spec = UnboundedContinuousTensorSpec(
                (
                    *batch_size,
                    1,
                )
            )
        if state_spec is None:
            state_spec = CompositeSpec(
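                # Descriptive note: the default state spec simply mirrors the
                # single-feature observation entry defined below.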
                observation=UnboundedContinuousTensorSpec(
                    (
                        *batch_size,
                        1,
                    )
                ),
                shape=batch_size,
            )
        if observation_spec is None:
            observation_spec = CompositeSpec(
                observation=UnboundedContinuousTensorSpec(
                    (
                        *batch_size,
                        1,
                    )
                ),
                shape=batch_size,
            )
        if reward_spec is None:
            reward_spec = UnboundedContinuousTensorSpec(
                (
                    *batch_size,
                    1,
                )
            )
        if done_spec is None:
            done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
        cls._output_spec = CompositeSpec(shape=batch_size)
        cls._output_spec["full_reward_spec"] = reward_spec
        cls._output_spec["full_done_spec"] = done_spec
        cls._output_spec["full_observation_spec"] = observation_spec
        cls._input_spec = CompositeSpec(
            full_action_spec=action_spec,
            full_state_spec=state_spec,
            shape=batch_size,
        )
        if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
            cls._output_spec["full_reward_spec"] = CompositeSpec(
                reward=cls._output_spec["full_reward_spec"], shape=batch_size
            )
        if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
            cls._output_spec["full_done_spec"] = CompositeSpec(
                done=cls._output_spec["full_done_spec"], shape=batch_size
            )
        if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
            cls._input_spec["full_action_spec"] = CompositeSpec(
                action=cls._input_spec["full_action_spec"], shape=batch_size
            )
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, device="cpu", batch_size=None):
        super(MockBatchedLockedEnv, self).__init__(
            device=device, batch_size=batch_size
        )
        self.counter = 0

    rand_step = MockSerialEnv.rand_step

    def _set_seed(self, seed: Optional[int]):
        assert seed >= 1
        self.seed = seed
        self.counter = seed % 17  # make counter a small number
        self.max_val = max(self.counter + 100, self.counter * 2)

    def _step(self, tensordict):
        if len(self.batch_size):
            leading_batch_size = (
                tensordict.shape[: -len(self.batch_size)]
                if tensordict is not None
                else []
            )
        else:
            leading_batch_size = tensordict.shape if tensordict is not None else []
        self.counter += 1
        # We use tensordict.batch_size instead of self.batch_size since this method
        # will also be used by MockBatchedUnLockedEnv
        n = (
            torch.full(
                [*leading_batch_size, *self.observation_spec["observation"].shape],
                self.counter,
            )
            .to(self.device)
            .to(torch.get_default_dtype())
        )
        done = self.counter >= self.max_val
        done = torch.full(
            (*leading_batch_size, *self.batch_size, 1),
            done,
            dtype=torch.bool,
            device=self.device,
        )
        return TensorDict(
            {"reward": n, "done": done, "terminated": done.clone(), "observation": n},
            batch_size=tensordict.batch_size,
            device=self.device,
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        self.max_val = max(self.counter + 100, self.counter * 2)
        batch_size = self.batch_size
        if len(batch_size):
            leading_batch_size = (
                tensordict.shape[: -len(self.batch_size)]
                if tensordict is not None
                else []
            )
        else:
            leading_batch_size = tensordict.shape if tensordict is not None else []
        n = (
            torch.full(
                [*leading_batch_size, *self.observation_spec["observation"].shape],
                self.counter,
            )
            .to(self.device)
            .to(torch.get_default_dtype())
        )
        done = self.counter >= self.max_val
        done = torch.full(
            (*leading_batch_size, *batch_size, 1),
            done,
            dtype=torch.bool,
            device=self.device,
        )
        return TensorDict(
            {"done": done, "terminated": done.clone(), "observation": n},
            [
                *leading_batch_size,
                *batch_size,
            ],
            device=self.device,
        )


class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
    """Mocks an env whose batch_size does not define the size of the output tensordict.

    The size of the output tensordict is defined by the input tensordict itself.
""" def __init__(self, device="cpu", batch_size=None): super(MockBatchedUnLockedEnv, self).__init__( batch_size=batch_size, device=device ) @classmethod def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, _batch_locked=False, **kwargs) class DiscreteActionVecMockEnv(_MockEnv): @classmethod def __new__( cls, *args, observation_spec=None, action_spec=None, state_spec=None, reward_spec=None, done_spec=None, from_pixels=False, categorical_action_encoding=False, **kwargs, ): batch_size = kwargs.setdefault("batch_size", torch.Size([])) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( observation=UnboundedContinuousTensorSpec( shape=torch.Size([*batch_size, size]) ), observation_orig=UnboundedContinuousTensorSpec( shape=torch.Size([*batch_size, size]) ), shape=batch_size, ) if action_spec is None: if categorical_action_encoding: action_spec_cls = DiscreteTensorSpec action_spec = action_spec_cls(n=7, shape=batch_size) else: action_spec_cls = OneHotDiscreteTensorSpec action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: reward_spec = CompositeSpec( reward=UnboundedContinuousTensorSpec(shape=(1,)) ) if done_spec is None: done_spec = CompositeSpec( terminated=DiscreteTensorSpec( 2, dtype=torch.bool, shape=(*batch_size, 1) ) ) if state_spec is None: cls._out_key = "observation_orig" state_spec = CompositeSpec( { cls._out_key: observation_spec["observation"], }, shape=batch_size, ) cls._output_spec = CompositeSpec(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = CompositeSpec( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, ) cls.from_pixels = from_pixels cls.categorical_action_encoding = categorical_action_encoding return super().__new__(*args, **kwargs) def _get_in_obs(self, obs): return obs def _get_out_obs(self, obs): return obs def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase: self.counter += 1 state = torch.zeros(self.size) + self.counter if tensordict is None: tensordict = TensorDict({}, self.batch_size, device=self.device) tensordict = tensordict.empty().set(self.out_key, self._get_out_obs(state)) tensordict = tensordict.set(self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) tensordict.set( "terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool) ) return tensordict def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: tensordict = tensordict.to(self.device) a = tensordict.get("action") if not self.categorical_action_encoding: assert (a.sum(-1) == 1).all() obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep tensordict = tensordict.empty() tensordict.set(self.out_key, self._get_out_obs(obs)) tensordict.set(self._out_key, self._get_out_obs(obs)) done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) reward = done.any(-1).unsqueeze(-1) # set done to False done = torch.zeros_like(done).all(-1).unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) tensordict.set("terminated", done.clone()) return tensordict class ContinuousActionVecMockEnv(_MockEnv): @classmethod def __new__( cls, *args, observation_spec=None, action_spec=None, state_spec=None, reward_spec=None, done_spec=None, from_pixels=False, **kwargs, ): batch_size = 
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        size = cls.size = 7
        if observation_spec is None:
            cls.out_key = "observation"
            observation_spec = CompositeSpec(
                observation=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, size])
                ),
                observation_orig=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, size])
                ),
                shape=batch_size,
            )
        if action_spec is None:
            action_spec = BoundedTensorSpec(
                -1,
                1,
                (
                    *batch_size,
                    7,
                ),
            )
        if reward_spec is None:
            reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
        if done_spec is None:
            done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
        if state_spec is None:
            cls._out_key = "observation_orig"
            state_spec = CompositeSpec(
                {
                    cls._out_key: observation_spec["observation"],
                },
                shape=batch_size,
            )
        cls._output_spec = CompositeSpec(shape=batch_size)
        cls._output_spec["full_reward_spec"] = reward_spec
        cls._output_spec["full_done_spec"] = done_spec
        cls._output_spec["full_observation_spec"] = observation_spec
        cls._input_spec = CompositeSpec(
            full_action_spec=action_spec,
            full_state_spec=state_spec,
            shape=batch_size,
        )
        cls.from_pixels = from_pixels
        return super().__new__(cls, *args, **kwargs)

    def _get_in_obs(self, obs):
        return obs

    def _get_out_obs(self, obs):
        return obs

    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        self.counter += 1
        self.step_count = 0
        # state = torch.zeros(self.size) + self.counter
        if tensordict is None:
            tensordict = TensorDict({}, self.batch_size, device=self.device)
        tensordict = tensordict.empty()
        tensordict.update(self.observation_spec.rand())
        # tensordict.set("next_" + self.out_key, self._get_out_obs(state))
        # tensordict.set("next_" + self._out_key, self._get_out_obs(state))
        tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
        tensordict.set(
            "terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)
        )
        return tensordict

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        self.step_count += 1
        tensordict = tensordict.to(self.device)
        a = tensordict.get("action")
        obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
        tensordict = tensordict.empty()  # empty tensordict
        tensordict.set(self.out_key, self._get_out_obs(obs))
        tensordict.set(self._out_key, self._get_out_obs(obs))
        done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
        while done.shape != tensordict.shape:
            done = done.any(-1)
        done = reward = done.unsqueeze(-1)
        tensordict.set("reward", reward.to(torch.get_default_dtype()))
        tensordict.set("done", done)
        tensordict.set("terminated", done)
        return tensordict

    def _obs_step(self, obs, a):
        return obs + a / self.maxstep


class DiscreteActionVecPolicy:
    in_keys = ["observation"]
    out_keys = ["action"]

    def _get_in_obs(self, tensordict):
        obs = tensordict.get(*self.in_keys)
        return obs

    def __call__(self, tensordict):
        obs = self._get_in_obs(tensordict)
        max_obs = (obs == obs.max(dim=-1, keepdim=True)[0]).cumsum(-1).argmax(-1)
        k = tensordict.get(*self.in_keys).shape[-1]
        max_obs = (max_obs + 1) % k
        action = torch.nn.functional.one_hot(max_obs, k)
        tensordict.set(*self.out_keys, action)
        return tensordict


class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        from_pixels=True,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if observation_spec is None:
            cls.out_key = "pixels"
            observation_spec = CompositeSpec(
                pixels=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 1, 7, 7])
                ),
                pixels_orig=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 1, 7, 7])
                ),
                shape=batch_size,
            )
        if action_spec is None:
            action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7))
        if reward_spec is None:
            reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
        if done_spec is None:
            done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
        if state_spec is None:
            cls._out_key = "pixels_orig"
            state_spec = CompositeSpec(
                {
                    cls._out_key: observation_spec["pixels_orig"].clone(),
                },
                shape=batch_size,
            )
        return super().__new__(
            *args,
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            state_spec=state_spec,
            from_pixels=from_pixels,
            done_spec=done_spec,
            **kwargs,
        )

    def _get_out_obs(self, obs):
        obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)
        return obs

    def _get_in_obs(self, obs):
        return obs.diagonal(0, -1, -2).squeeze()


class DiscreteActionConvMockEnvNumpy(DiscreteActionConvMockEnv):
    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        from_pixels=True,
        categorical_action_encoding=False,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if observation_spec is None:
            cls.out_key = "pixels"
            observation_spec = CompositeSpec(
                pixels=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 7, 7, 3])
                ),
                pixels_orig=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 7, 7, 3])
                ),
                shape=batch_size,
            )
        if action_spec is None:
            action_spec_cls = (
                DiscreteTensorSpec
                if categorical_action_encoding
                else OneHotDiscreteTensorSpec
            )
            action_spec = action_spec_cls(7, shape=(*batch_size, 7))
        if state_spec is None:
            cls._out_key = "pixels_orig"
            state_spec = CompositeSpec(
                {
                    cls._out_key: observation_spec["pixels_orig"],
                },
                shape=batch_size,
            )
        return super().__new__(
            *args,
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            state_spec=state_spec,
            from_pixels=from_pixels,
            categorical_action_encoding=categorical_action_encoding,
            **kwargs,
        )

    def _get_out_obs(self, obs):
        obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1)
        obs = obs.expand(*obs.shape[:-1], 3)
        return obs

    def _get_in_obs(self, obs):
        return obs.diagonal(0, -2, -3)[..., 0, :]

    def _obs_step(self, obs, a):
        return obs + a.unsqueeze(-1) / self.maxstep


class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        from_pixels=True,
        pixel_shape=None,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if pixel_shape is None:
            pixel_shape = [1, 7, 7]
        if observation_spec is None:
            cls.out_key = "pixels"
            observation_spec = CompositeSpec(
                pixels=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, *pixel_shape])
                ),
                pixels_orig=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, *pixel_shape])
                ),
                shape=batch_size,
            )
        if action_spec is None:
            action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]])
        if reward_spec is None:
            reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
        if done_spec is None:
            done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
        if state_spec is None:
            cls._out_key = "pixels_orig"
            state_spec = CompositeSpec(
                {cls._out_key: observation_spec["pixels"]}, shape=batch_size
            )
        return super().__new__(
            *args,
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            from_pixels=from_pixels,
            state_spec=state_spec,
            done_spec=done_spec,
            **kwargs,
        )

    def _get_out_obs(self, obs):
        obs = torch.diag_embed(obs, 0, -2, -1)
        return obs

    def _get_in_obs(self, obs):
        obs = obs.diagonal(0, -1, -2)
        return obs


class ContinuousActionConvMockEnvNumpy(ContinuousActionConvMockEnv):
    @classmethod
    def __new__(
        cls,
        *args,
        observation_spec=None,
        action_spec=None,
        state_spec=None,
        reward_spec=None,
        done_spec=None,
        from_pixels=True,
        **kwargs,
    ):
        batch_size = kwargs.setdefault("batch_size", torch.Size([]))
        if observation_spec is None:
            cls.out_key = "pixels"
            observation_spec = CompositeSpec(
                pixels=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 7, 7, 3])
                ),
                pixels_orig=UnboundedContinuousTensorSpec(
                    shape=torch.Size([*batch_size, 7, 7, 3])
                ),
            )
        return super().__new__(
            *args,
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            state_spec=state_spec,
            from_pixels=from_pixels,
            **kwargs,
        )

    def _get_out_obs(self, obs):
        obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1)
        obs = obs.expand(*obs.shape[:-1], 3)
        return obs

    def _get_in_obs(self, obs):
        return obs.diagonal(0, -2, -3)[..., 0, :]

    def _obs_step(self, obs, a):
        return obs + a / self.maxstep


class DiscreteActionConvPolicy(DiscreteActionVecPolicy):
    in_keys = ["pixels"]
    out_keys = ["action"]

    def _get_in_obs(self, tensordict):
        obs = tensordict.get(*self.in_keys).diagonal(0, -1, -2).squeeze()
        return obs


class DummyModelBasedEnvBase(ModelBasedEnvBase):
    """Dummy environment for model-based RL algorithms.

    This class is meant to be used to test model-based environments.

    Args:
        world_model (WorldModel): the world model to use for the environment.
        device (str or torch.device, optional): the device to use for the environment.
        dtype (torch.dtype, optional): the dtype to use for the environment.
        batch_size (sequence of int, optional): the batch size to use for the environment.
    """

    def __init__(
        self,
        world_model,
        device="cpu",
        dtype=None,
        batch_size=None,
    ):
        super().__init__(
            world_model,
            device=device,
            dtype=dtype,
            batch_size=batch_size,
        )
        self.observation_spec = CompositeSpec(
            hidden_observation=UnboundedContinuousTensorSpec(
                (
                    *self.batch_size,
                    4,
                )
            ),
            shape=self.batch_size,
        )
        self.state_spec = CompositeSpec(
            hidden_observation=UnboundedContinuousTensorSpec(
                (
                    *self.batch_size,
                    4,
                )
            ),
            shape=self.batch_size,
        )
        self.action_spec = UnboundedContinuousTensorSpec(
            (
                *self.batch_size,
                1,
            )
        )
        self.reward_spec = UnboundedContinuousTensorSpec(
            (
                *self.batch_size,
                1,
            )
        )

    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        td = TensorDict(
            {
                "hidden_observation": self.state_spec["hidden_observation"].rand(),
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return td


class ActionObsMergeLinear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, observation, action):
        return self.linear(torch.cat([observation, action], dim=-1))


class CountingEnvCountPolicy:
    def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
        self.action_spec = action_spec
        self.action_key = action_key

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        return td.set(self.action_key, self.action_spec.zero() + 1)


class CountingEnvCountModule(nn.Module):
    def __init__(self, action_spec: TensorSpec):
        super().__init__()
        self.action_spec = action_spec

    def forward(self, t):
        return self.action_spec.zero() + 1
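

# Illustrative note (not part of the original helpers): ActionObsMergeLinear
# expects ``in_size`` to be the sum of the observation and action feature sizes,
# e.g. with hypothetical sizes
#     net = ActionObsMergeLinear(in_size=4 + 1, out_size=3)
#     out = net(torch.randn(4), torch.randn(1))  # -> shape (3,)
# CountingEnvCountPolicy and CountingEnvCountModule always emit an action of
# ones, so the counting envs below advance their counter by one at every step.
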
class CountingEnv(EnvBase):
    """An env that is done after a given number of steps.

    The action is the count increment.
    """

    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.max_steps = max_steps
        self.start_val = start_val
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(
                (
                    *self.batch_size,
                    1,
                ),
                dtype=torch.int32,
                device=self.device,
            ),
            shape=self.batch_size,
            device=self.device,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(
            (
                *self.batch_size,
                1,
            ),
            device=self.device,
        )
        self.done_spec = DiscreteTensorSpec(
            2,
            dtype=torch.bool,
            shape=(
                *self.batch_size,
                1,
            ),
            device=self.device,
        )
        self.action_spec = BinaryDiscreteTensorSpec(
            n=1, shape=[*self.batch_size, 1], device=self.device
        )
        self.register_buffer(
            "count",
            torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
        )

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        if tensordict is not None and "_reset" in tensordict.keys():
            _reset = tensordict.get("_reset")
            self.count[_reset] = self.start_val
        else:
            self.count[:] = self.start_val
        return TensorDict(
            source={
                "observation": self.count.clone(),
                "done": self.count > self.max_steps,
                "terminated": self.count > self.max_steps,
            },
            batch_size=self.batch_size,
            device=self.device,
        )

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        action = tensordict.get(self.action_key)
        self.count += action.to(torch.int).to(self.device)
        tensordict = TensorDict(
            source={
                "observation": self.count.clone(),
                "done": self.count > self.max_steps,
                "terminated": self.count > self.max_steps,
                "reward": torch.zeros_like(self.count, dtype=torch.float),
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict


class IncrementingEnv(CountingEnv):
    # Same as CountingEnv but always increments the count by 1 regardless of the action.
    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        self.count += 1  # The only difference with CountingEnv.
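        # Apart from ignoring the action above, this mirrors CountingEnv._step.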
        tensordict = TensorDict(
            source={
                "observation": self.count.clone(),
                "done": self.count > self.max_steps,
                "terminated": self.count > self.max_steps,
                "reward": torch.zeros_like(self.count, dtype=torch.float),
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict


class NestedCountingEnv(CountingEnv):
    # an env with nested reward and done states
    def __init__(
        self,
        max_steps: int = 5,
        start_val: int = 0,
        nest_obs_action: bool = True,
        nest_done: bool = True,
        nest_reward: bool = True,
        nested_dim: int = 3,
        has_root_done: bool = False,
        **kwargs,
    ):
        super().__init__(max_steps=max_steps, start_val=start_val, **kwargs)
        self.nested_dim = nested_dim
        self.has_root_done = has_root_done
        self.nested_obs_action = nest_obs_action
        self.nested_done = nest_done
        self.nested_reward = nest_reward
        if self.nested_obs_action:
            self.observation_spec = CompositeSpec(
                {
                    "data": CompositeSpec(
                        {
                            "states": self.observation_spec["observation"]
                            .unsqueeze(-1)
                            .expand(*self.batch_size, self.nested_dim, 1)
                        },
                        shape=(
                            *self.batch_size,
                            self.nested_dim,
                        ),
                    )
                },
                shape=self.batch_size,
            )
            self.action_spec = CompositeSpec(
                {
                    "data": CompositeSpec(
                        {
                            "action": self.action_spec.unsqueeze(-1).expand(
                                *self.batch_size, self.nested_dim, 1
                            )
                        },
                        shape=(
                            *self.batch_size,
                            self.nested_dim,
                        ),
                    )
                },
                shape=self.batch_size,
            )
        if self.nested_reward:
            self.reward_spec = CompositeSpec(
                {
                    "data": CompositeSpec(
                        {
                            "reward": self.reward_spec.unsqueeze(-1).expand(
                                *self.batch_size, self.nested_dim, 1
                            )
                        },
                        shape=(
                            *self.batch_size,
                            self.nested_dim,
                        ),
                    )
                },
                shape=self.batch_size,
            )
        if self.nested_done:
            done_spec = self.full_done_spec.unsqueeze(-1).expand(
                *self.batch_size, self.nested_dim
            )
            done_spec = CompositeSpec(
                {"data": done_spec},
                shape=self.batch_size,
            )
            if self.has_root_done:
                done_spec["done"] = DiscreteTensorSpec(
                    2,
                    shape=(
                        *self.batch_size,
                        1,
                    ),
                    dtype=torch.bool,
                )
            self.done_spec = done_spec

    def _reset(self, tensordict):
        # check that reset works as expected
        if tensordict is not None:
            if self.nested_done:
                if not self.has_root_done:
                    assert "_reset" not in tensordict.keys()
            else:
                assert ("data", "_reset") not in tensordict.keys(True)
        tensordict_reset = super()._reset(tensordict)
        if self.nested_done:
            for done_key in self.done_keys:
                if isinstance(done_key, str):
                    continue
                if self.has_root_done:
                    done = tensordict_reset.get(done_key[-1])
                else:
                    done = tensordict_reset.pop(done_key[-1])
                tensordict_reset.set(
                    done_key,
                    (done.unsqueeze(-2).expand(*self.batch_size, self.nested_dim, 1)),
                )
        if self.nested_obs_action:
            obs = tensordict_reset.pop("observation")
            tensordict_reset.set(
                ("data", "states"),
                (obs.unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)),
            )
        if "data" in tensordict_reset.keys():
            tensordict_reset.get("data").batch_size = (
                *self.batch_size,
                self.nested_dim,
            )
        return tensordict_reset

    def _step(self, tensordict):
        if self.nested_obs_action:
            tensordict = tensordict.clone()
            tensordict["data"].batch_size = self.batch_size
            tensordict[self.action_key] = tensordict[self.action_key].max(-2)[0]
        next_tensordict = super()._step(tensordict)
        if self.nested_obs_action:
            tensordict[self.action_key] = (
                tensordict[self.action_key]
                .unsqueeze(-1)
                .expand(*self.batch_size, self.nested_dim, 1)
            )
            if "data" in tensordict.keys():
                tensordict["data"].batch_size = (*self.batch_size, self.nested_dim)
        if self.nested_done:
            for done_key in self.done_keys:
                if isinstance(done_key, str):
                    continue
                if self.has_root_done:
                    done = next_tensordict.get(done_key[-1])
                else:
                    done = next_tensordict.pop(done_key[-1])
                next_tensordict.set(
                    done_key,
                    (done.unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)),
                )
        if self.nested_obs_action:
            next_tensordict.set(
                ("data", "states"),
                (
                    next_tensordict.pop("observation")
                    .unsqueeze(-1)
                    .expand(*self.batch_size, self.nested_dim, 1)
                ),
            )
        if self.nested_reward:
            next_tensordict.set(
                self.reward_key,
                (
                    next_tensordict.pop("reward")
                    .unsqueeze(-1)
                    .expand(*self.batch_size, self.nested_dim, 1)
                ),
            )
        if "data" in next_tensordict.keys():
            next_tensordict.get("data").batch_size = (*self.batch_size, self.nested_dim)
        return next_tensordict


class CountingBatchedEnv(EnvBase):
    """An env that is done after a given number of steps.

    The action is the count increment.

    Unlike ``CountingEnv``, different envs of the batch can have different max_steps
    """

    def __init__(
        self,
        max_steps: torch.Tensor = None,
        start_val: torch.Tensor = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if max_steps is None:
            max_steps = torch.tensor(5)
        if start_val is None:
            start_val = torch.zeros((), dtype=torch.int32)
        if not max_steps.shape == self.batch_size:
            raise RuntimeError("batch_size and max_steps shape must match.")
        self.max_steps = max_steps
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(
                (
                    *self.batch_size,
                    1,
                ),
                dtype=torch.int32,
            ),
            shape=self.batch_size,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(
            (
                *self.batch_size,
                1,
            )
        )
        self.done_spec = DiscreteTensorSpec(
            2,
            dtype=torch.bool,
            shape=(
                *self.batch_size,
                1,
            ),
        )
        self.action_spec = BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1])
        self.count = torch.zeros(
            (*self.batch_size, 1), device=self.device, dtype=torch.int
        )
        if start_val.numel() == self.batch_size.numel():
            self.start_val = start_val.view(*self.batch_size, 1)
        elif start_val.numel() <= 1:
            self.start_val = start_val.expand_as(self.count)

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        if tensordict is not None and "_reset" in tensordict.keys():
            _reset = tensordict.get("_reset")
            self.count[_reset] = self.start_val[_reset].view_as(self.count[_reset])
        else:
            self.count[:] = self.start_val.view_as(self.count)
        return TensorDict(
            source={
                "observation": self.count.clone(),
                "done": self.count > self.max_steps.view_as(self.count),
                "terminated": self.count > self.max_steps.view_as(self.count),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        action = tensordict.get("action")
        self.count += action.to(torch.int).view_as(self.count)
        tensordict = TensorDict(
            source={
                "observation": self.count.clone(),
                "done": self.count > self.max_steps.unsqueeze(-1),
                "terminated": self.count > self.max_steps.unsqueeze(-1),
                "reward": torch.zeros_like(self.count, dtype=torch.float),
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict


class HeteroCountingEnvPolicy:
    def __init__(self, full_action_spec: TensorSpec, count: bool = True):
        self.full_action_spec = full_action_spec
        self.count = count

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        action_td = self.full_action_spec.zero()
        if self.count:
            action_td.apply_(lambda x: x + 1)
        return td.update(action_td)


class HeteroCountingEnv(EnvBase):
    """A heterogeneous, counting Env."""

    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.n_nested_dim = 3
        self.max_steps = max_steps
        self.start_val = start_val

        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
        count[:] = self.start_val
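        # Registering the counter as a buffer keeps it on the env's device
        # (EnvBase is an nn.Module), so it follows the env through .to(device).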
        self.register_buffer("count", count)

        obs_specs = []
        action_specs = []
        for index in range(self.n_nested_dim):
            obs_specs.append(self.get_agent_obs_spec(index))
            action_specs.append(self.get_agent_action_spec(index))
        obs_specs = torch.stack(obs_specs, dim=0)
        obs_spec_unlazy = consolidate_spec(obs_specs)
        action_specs = torch.stack(action_specs, dim=0)

        self.unbatched_observation_spec = CompositeSpec(
            lazy=obs_spec_unlazy,
            state=UnboundedContinuousTensorSpec(
                shape=(
                    64,
                    64,
                    3,
                )
            ),
        )
        self.unbatched_action_spec = CompositeSpec(
            lazy=action_specs,
        )
        self.unbatched_reward_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "reward": UnboundedContinuousTensorSpec(
                            shape=(self.n_nested_dim, 1)
                        )
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )
        self.unbatched_done_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "done": DiscreteTensorSpec(
                            n=2,
                            shape=(self.n_nested_dim, 1),
                            dtype=torch.bool,
                        ),
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )

        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def get_agent_obs_spec(self, i):
        camera = BoundedTensorSpec(low=0, high=200, shape=(7, 7, 3))
        vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
        vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
        lidar = BoundedTensorSpec(low=0, high=5, shape=(8,))
        tensor_0 = UnboundedContinuousTensorSpec(shape=(1,))
        tensor_1 = BoundedTensorSpec(low=0, high=3, shape=(1, 2))
        tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3))
        if i == 0:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_3d,
                    "tensor_0": tensor_0,
                }
            )
        elif i == 1:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_2d,
                    "tensor_1": tensor_1,
                }
            )
        elif i == 2:
            return CompositeSpec(
                {
                    "camera": camera,
                    "vector": vector_2d,
                    "tensor_2": tensor_2,
                }
            )
        else:
            raise ValueError(f"Index {i} undefined for index 3")

    def get_agent_action_spec(self, i):
        action_3d = BoundedTensorSpec(low=-1, high=1, shape=(3,))
        action_2d = BoundedTensorSpec(low=-1, high=1, shape=(2,))
        # Some have 2d action and some 3d
        # TODO Introduce composite heterogeneous actions
        if i == 0:
            ret = action_3d
        elif i == 1:
            ret = action_2d
        elif i == 2:
            ret = action_2d
        else:
            raise ValueError(f"Index {i} undefined for index 3")
        return CompositeSpec({"action": ret})

    def _reset(
        self,
        tensordict: TensorDictBase = None,
        **kwargs,
    ) -> TensorDictBase:
        if tensordict is not None and self.reset_keys[0] in tensordict.keys(True):
            _reset = tensordict.get(self.reset_keys[0]).squeeze(-1).any(-1)
            self.count[_reset] = self.start_val
        else:
            self.count[:] = self.start_val

        reset_td = self.observation_spec.zero()
        reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
        reset_td.update(self.output_spec["full_done_spec"].zero())

        assert reset_td.batch_size == self.batch_size
        for key in reset_td.keys(True):
            assert "_reset" not in key
        return reset_td

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        actions = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool)
        for i in range(self.n_nested_dim):
            action = tensordict["lazy"][..., i]["action"]
            action = action[..., 0].to(torch.bool)
            actions += action
        self.count += actions.unsqueeze(-1).to(torch.int)

        td = self.observation_spec.zero()
        td.apply_(lambda x: x + expand_right(self.count, x.shape))
        td.update(self.output_spec["full_done_spec"].zero())
        td.update(self.output_spec["full_reward_spec"].zero())

        assert td.batch_size == self.batch_size
        for done_key in self.done_keys:
            td[done_key] = expand_right(
                self.count > self.max_steps,
                self.full_done_spec[done_key].shape,
            )
        return td

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)


class MultiKeyCountingEnvPolicy:
    def __init__(
        self,
        full_action_spec: TensorSpec,
        count: bool = True,
        deterministic: bool = False,
    ):
        if not deterministic and not count:
            raise ValueError("A non-counting policy is always deterministic")
        self.full_action_spec = full_action_spec
        self.count = count
        self.deterministic = deterministic

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        action_td = self.full_action_spec.zero()
        if self.count:
            if self.deterministic:
                action_td["nested_1", "action"] += 1
                action_td["nested_2", "azione"] += 1
                action_td["action"][..., 1] = 1
            else:
                # We choose an action at random
                choice = torch.randint(0, 3, ()).item()
                if choice == 0:
                    action_td["nested_1", "action"] += 1
                elif choice == 1:
                    action_td["nested_2", "azione"] += 1
                else:
                    action_td["action"][..., 1] = 1
        return td.update(action_td)


class MultiKeyCountingEnv(EnvBase):
    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)

        self.max_steps = max_steps
        self.start_val = start_val
        self.nested_dim_1 = 3
        self.nested_dim_2 = 2

        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
        count_nested_1 = torch.zeros(
            (*self.batch_size, self.nested_dim_1, 1),
            device=self.device,
            dtype=torch.int,
        )
        count_nested_2 = torch.zeros(
            (*self.batch_size, self.nested_dim_2, 1),
            device=self.device,
            dtype=torch.int,
        )
        count[:] = self.start_val
        count_nested_1[:] = self.start_val
        count_nested_2[:] = self.start_val
        self.register_buffer("count", count)
        self.register_buffer("count_nested_1", count_nested_1)
        self.register_buffer("count_nested_2", count_nested_2)

        self.make_specs()

        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def make_specs(self):
        self.unbatched_observation_spec = CompositeSpec(
            nested_1=CompositeSpec(
                observation=BoundedTensorSpec(
                    low=0, high=200, shape=(self.nested_dim_1, 3)
                ),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                observation=UnboundedContinuousTensorSpec(
                    shape=(self.nested_dim_2, 2)
                ),
                shape=(self.nested_dim_2,),
            ),
            observation=UnboundedContinuousTensorSpec(
                shape=(
                    10,
                    10,
                    3,
                )
            ),
        )

        self.unbatched_action_spec = CompositeSpec(
            nested_1=CompositeSpec(
                action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                azione=BoundedTensorSpec(low=0, high=100, shape=(self.nested_dim_2, 1)),
                shape=(self.nested_dim_2,),
            ),
            action=OneHotDiscreteTensorSpec(n=2),
        )

        self.unbatched_reward_spec = CompositeSpec(
            nested_1=CompositeSpec(
                gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)),
                shape=(self.nested_dim_2,),
            ),
            reward=UnboundedContinuousTensorSpec(shape=(1,)),
        )
        self.unbatched_done_spec = CompositeSpec(
            nested_1=CompositeSpec(
                done=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_1, 1),
                    dtype=torch.bool,
                ),
                terminated=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_1, 1),
                    dtype=torch.bool,
                ),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                done=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_2, 1),
                    dtype=torch.bool,
                ),
                terminated=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_2, 1),
                    dtype=torch.bool,
                ),
                shape=(self.nested_dim_2,),
            ),
            # done at the root always prevails
            done=DiscreteTensorSpec(
                n=2,
                shape=(1,),
                dtype=torch.bool,
            ),
            terminated=DiscreteTensorSpec(
                n=2,
                shape=(1,),
                dtype=torch.bool,
            ),
        )

    def _reset(
        self,
        tensordict: TensorDictBase = None,
        **kwargs,
    ) -> TensorDictBase:
        reset_all = False
        if tensordict is not None:
            _reset = tensordict.get("_reset", None)
            if _reset is not None:
                self.count[_reset] = self.start_val
                self.count_nested_1[_reset] = self.start_val
                self.count_nested_2[_reset] = self.start_val
            else:
                reset_all = True
        if tensordict is None or reset_all:
            self.count[:] = self.start_val
            self.count_nested_1[:] = self.start_val
            self.count_nested_2[:] = self.start_val

        reset_td = self.observation_spec.zero()
        reset_td["observation"] += expand_right(
            self.count, reset_td["observation"].shape
        )
        reset_td["nested_1", "observation"] += expand_right(
            self.count_nested_1, reset_td["nested_1", "observation"].shape
        )
        reset_td["nested_2", "observation"] += expand_right(
            self.count_nested_2, reset_td["nested_2", "observation"].shape
        )
        reset_td.update(self.output_spec["full_done_spec"].zero())

        assert reset_td.batch_size == self.batch_size
        return reset_td

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        # Each action has a corresponding reward, done, and observation
        reward = self.output_spec["full_reward_spec"].zero()
        done = self.output_spec["full_done_spec"].zero()
        td = self.observation_spec.zero()

        one_hot_action = tensordict["action"]
        one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1)
        reward["reward"] += one_hot_action.to(torch.float)
        self.count += one_hot_action.to(torch.int)
        td["observation"] += expand_right(self.count, td["observation"].shape)
        done["done"] = self.count > self.max_steps
        done["terminated"] = self.count > self.max_steps

        discrete_action = tensordict["nested_1"]["action"].unsqueeze(-1)
        reward["nested_1"]["gift"] += discrete_action.to(torch.float)
        self.count_nested_1 += discrete_action.to(torch.int)
        td["nested_1", "observation"] += expand_right(
            self.count_nested_1, td["nested_1", "observation"].shape
        )
        done["nested_1", "done"] = self.count_nested_1 > self.max_steps
        done["nested_1", "terminated"] = self.count_nested_1 > self.max_steps

        continuous_action = tensordict["nested_2"]["azione"]
        reward["nested_2"]["reward"] += continuous_action.to(torch.float)
        self.count_nested_2 += continuous_action.to(torch.bool)
        td["nested_2", "observation"] += expand_right(
            self.count_nested_2, td["nested_2", "observation"].shape
        )
        done["nested_2", "done"] = self.count_nested_2 > self.max_steps
        done["nested_2", "terminated"] = self.count_nested_2 > self.max_steps

        td.update(done)
        td.update(reward)

        assert td.batch_size == self.batch_size
        return td

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)
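

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative assumption, not part of the test helpers):
# drive CountingEnv with its counting policy and check that it terminates.
# The calls below (reset/rollout) are the standard EnvBase entry points.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = CountingEnv(max_steps=3)
    policy = CountingEnvCountPolicy(env.action_spec)
    rollout = env.rollout(10, policy)
    # CountingEnv flags "done" once the count exceeds max_steps, so the
    # rollout stops early (break_when_any_done defaults to True).
    assert rollout["next", "done"].any()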