diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 4cfc8470a15..d2f0d11643a 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -16,7 +16,7 @@ TensorDictSequential as Seq, ) from torch.nn import functional as F -from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Bounded, Unbounded from torchrl.modules import MLP, QValueActor, TanhNormal from torchrl.objectives import ( A2CLoss, @@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= value = Seq(common, value_head) value(actor(td)) - loss = SACLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) benchmark(loss, td) @@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden value = Seq(common, value_head) value(actor(td)) - loss = REDQLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) benchmark(loss, td) @@ -373,9 +369,7 @@ def test_redq_deprec_speed( value = Seq(common, value_head) value(actor(td)) - loss = REDQLoss_deprecated( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) benchmark(loss, td) @@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= loss = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), ) loss(td) @@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= value = Seq(common, value_head) value(actor(td)) - loss = CQLLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) benchmark(loss, td) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 0dca499f4d9..ed5639fcf59 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -877,11 +877,58 @@ TensorSpec .. _ref_specs: -The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such -as shape, device, dtype and domain. +The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations +actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain. + It is important that your environment specs match the input and output that it sends and receives, as -:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. -Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. +:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawn processes. +Check the :func:`torchrl.envs.utils.check_env_specs` method for a sanity check. + +If needed, specs can be automatially generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td` +function. + +Specs fall in two main categories, numerical and categorical. + +.. table:: Numerical TensorSpec subclasses. + + +-------------------------------------------------------------------------------+ + | Numerical | + +=====================================+=========================================+ + | Bounded | Unbounded | + +-----------------+-------------------+-------------------+---------------------+ + | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous | + +-----------------+-------------------+-------------------+---------------------+ + +Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or +explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous` +or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class. +See these classes for further information. + +.. table:: Categorical TensorSpec subclasses. + + +------------------------------------------------------------------+ + | Categorical | + +========+=============+=============+==================+==========+ + | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary | + +--------+-------------+-------------+------------------+----------+ + +Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be +combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as +:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these +cases is the :class:`~torchrl.data.Composite` spec. + +Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be +expanded accordingly. +Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class. + +Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and +:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``) +or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be. + +Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding +data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as +:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are +not predictable. .. currentmodule:: torchrl.data @@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. :template: rl_template.rst TensorSpec + Binary + Bounded + Categorical + Composite + MultiCategorical + MultiOneHot + NonTensor + OneHotDiscrete + Stacked + StackedComposite + Unbounded + UnboundedContinuous + UnboundedDiscrete + +The following classes are deprecated and just point to the classes above: + +.. currentmodule:: torchrl.data + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + BinaryDiscreteTensorSpec BoundedTensorSpec CompositeSpec DiscreteTensorSpec + LazyStackedCompositeSpec + LazyStackedTensorSpec MultiDiscreteTensorSpec MultiOneHotDiscreteTensorSpec NonTensorSpec OneHotDiscreteTensorSpec UnboundedContinuousTensorSpec UnboundedDiscreteTensorSpec - LazyStackedTensorSpec - LazyStackedCompositeSpec - NonTensorSpec Reinforcement Learning From Human Feedback (RLHF) ------------------------------------------------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 11a5bb041a6..283bd2a631b 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -28,9 +28,9 @@ Each env will have the following attributes: This is especially useful for transforms (see below). For parametric environments (e.g. model-based environments), the device does represent the hardware that will be used to compute the operations. -- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object +- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object containing all the observation key-spec pairs. -- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object +- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object containing all the input key-spec pairs (except action). For most stateful environments, this container will be empty. - :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object @@ -39,10 +39,10 @@ Each env will have the following attributes: the reward spec. - :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing the done-flag spec. See the section on trajectory termination below. -- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing +- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). It is locked and should not be modified directly. -- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing +- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). It is locked and should not be modified directly. @@ -433,28 +433,28 @@ only the done flag is shared across agents (as in VMAS): ... action_specs.append(agent_i_action_spec) ... reward_specs.append(agent_i_reward_spec) ... observation_specs.append(agent_i_observation_spec) - >>> env.action_spec = CompositeSpec( + >>> env.action_spec = Composite( ... { - ... "agents": CompositeSpec( + ... "agents": Composite( ... {"action": torch.stack(action_specs)}, shape=(env.n_agents,) ... ) ... } ...) - >>> env.reward_spec = CompositeSpec( + >>> env.reward_spec = Composite( ... { - ... "agents": CompositeSpec( + ... "agents": Composite( ... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) ... ) ... } ...) - >>> env.observation_spec = CompositeSpec( + >>> env.observation_spec = Composite( ... { - ... "agents": CompositeSpec( + ... "agents": Composite( ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) ... ) ... } ...) - >>> env.done_spec = DiscreteTensorSpec( + >>> env.done_spec = Categorical( ... n=2, ... shape=torch.Size((1,)), ... dtype=torch.bool, @@ -582,23 +582,23 @@ the ``return_contiguous=False`` argument. Here is a working example: >>> from torchrl.envs import EnvBase - >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec + >>> from torchrl.data import Unbounded, Composite, Bounded, Binary >>> import torch >>> from tensordict import TensorDict, TensorDictBase >>> >>> class EnvWithDynamicSpec(EnvBase): ... def __init__(self, max_count=5): ... super().__init__(batch_size=()) - ... self.observation_spec = CompositeSpec( - ... observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)), + ... self.observation_spec = Composite( + ... observation=Unbounded(shape=(3, -1, 2)), ... ) - ... self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,)) - ... self.full_done_spec = CompositeSpec( - ... done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - ... terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - ... truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), + ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + ... self.full_done_spec = Composite( + ... done=Binary(1, shape=(1,), dtype=torch.bool), + ... terminated=Binary(1, shape=(1,), dtype=torch.bool), + ... truncated=Binary(1, shape=(1,), dtype=torch.bool), ... ) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float) + ... self.reward_spec = Unbounded((1,), dtype=torch.float) ... self.count = 0 ... self.max_count = max_count ... diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 5b05fc32194..84603485f53 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -163,11 +163,91 @@ resulting action in the input tensordict along with the list of action values. >>> from tensordict import TensorDict >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) >>> # we have 4 actions to choose from - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) + >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available + >>> module = nn.Linear(3, 4) + >>> qvalue_actor = QValueActor(module=module, spec=action_spec) + >>> qvalue_actor(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), + action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=None, + is_shared=False) + +Distributional Q-learning is slightly different: in this case, the value network +does not output a scalar value for each state-action value. +Instead, the value space is divided in a an arbitrary number of "bins". The +value network outputs a probability that the state-action value belongs to one bin +or another. +Hence, for a state space of dimension M, an action space of dimension N and a number of bins B, +the value network encodes a +of a (s,a) -> v map. This map can be a table or a function. +For discrete action spaces with continuous (or near-continuous such as pixels) +states, it is customary to use a non-linear model such as a neural network for +the map. +The semantic of the Q-Value network is hopefully quite simple: we just need to +feed a tensor-to-tensor map that given a certain state (the input tensor), +outputs a list of action values to choose from. The wrapper will write the +resulting action in the input tensordict along with the list of action values. + + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional + >>> from torch import nn + >>> from torchrl.data import OneHot + >>> from torchrl.modules.tensordict_module.actors import QValueActor + >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) + >>> # we have 4 actions to choose from + >>> action_spec = OneHot(4) + >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available + >>> module = nn.Linear(3, 4) + >>> qvalue_actor = QValueActor(module=module, spec=action_spec) + >>> qvalue_actor(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), + action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=None, + is_shared=False) + +Distributional Q-learning is slightly different: in this case, the value network +does not output a scalar value for each state-action value. +Instead, the value space is divided in a an arbitrary number of "bins". The +value network outputs a probability that the state-action value belongs to one bin +or another. +Hence, for a state space of dimension M, an action space of dimension N and a number of bins B, +the value network encodes a +of a (s,a) -> v map. This map can be a table or a function. +For discrete action spaces with continuous (or near-continuous such as pixels) +states, it is customary to use a non-linear model such as a neural network for +the map. +The semantic of the Q-Value network is hopefully quite simple: we just need to +feed a tensor-to-tensor map that given a certain state (the input tensor), +outputs a list of action values to choose from. The wrapper will write the +resulting action in the input tensordict along with the list of action values. + + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn.functional_modules import make_functional + >>> from torch import nn + >>> from torchrl.data import OneHot + >>> from torchrl.modules.tensordict_module.actors import QValueActor + >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) + >>> # we have 4 actions to choose from + >>> action_spec = OneHot(4) >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available >>> module = nn.Linear(3, 4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) @@ -196,13 +276,57 @@ class: >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot + >>> from torchrl.modules import DistributionalQValueActor, MLP + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> nbins = 3 + >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 + >>> module = MLP(out_features=(nbins, 4), depth=2) + >>> action_spec = OneHot(4) + >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) + >>> td = qvalue_actor(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), + action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=None, + is_shared=False) + + >>> import torch + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from torchrl.data import OneHot + >>> from torchrl.modules import DistributionalQValueActor, MLP + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> nbins = 3 + >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 + >>> module = MLP(out_features=(nbins, 4), depth=2) + >>> action_spec = OneHot(4) + >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) + >>> td = qvalue_actor(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), + action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=None, + is_shared=False) + + >>> import torch + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) >>> td = qvalue_actor(td) >>> print(td) diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index b140ee7bc67..7b7e053f498 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -114,7 +114,7 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.data import BoundedTensorSpec + from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import RandomPolicy @@ -128,7 +128,7 @@ def make_env(): collector = DistributedDataCollector( [EnvCreator(make_env)] * num_jobs, - policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + policy=RandomPolicy(Bounded(-1, 1, shape=(1,))), launcher="submitit_delayed", frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index adff8864413..f63c4d17409 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -113,7 +113,7 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.data import BoundedTensorSpec + from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import RandomPolicy @@ -127,7 +127,7 @@ def make_env(): collector = RPCDataCollector( [EnvCreator(make_env)] * num_jobs, - policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + policy=RandomPolicy(Bounded(-1, 1, shape=(1,))), launcher="submitit_delayed", frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/examples/envs/gym-async-info-reader.py b/examples/envs/gym-async-info-reader.py index 3f98e039290..72330f13030 100644 --- a/examples/envs/gym-async-info-reader.py +++ b/examples/envs/gym-async-info-reader.py @@ -48,7 +48,7 @@ def step(self, action): if __name__ == "__main__": import torch - from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + from torchrl.data.tensor_specs import Unbounded from torchrl.envs import check_env_specs, GymEnv, GymWrapper args = parser.parse_args() @@ -66,7 +66,7 @@ def step(self, action): keys = ["field1"] specs = [ - UnboundedContinuousTensorSpec(shape=(num_envs, 3), dtype=torch.float64), + Unbounded(shape=(num_envs, 3), dtype=torch.float64), ] # Create an info reader: this object will read the info and write its content to the tensordict diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index f8c18147306..42ef4301c4d 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -226,6 +226,9 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index d115174eb9c..2b390d39d2a 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -212,6 +212,9 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 58fa8541d90..6a09ff715e4 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -7,8 +7,8 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec -from torchrl.data.tensor_specs import DiscreteBox +from torchrl.data import Composite +from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -92,7 +92,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): + if isinstance(proof_environment.action_spec.space, CategoricalBox): num_outputs = proof_environment.action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 9bb5a1f6307..996706ce4f9 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -8,7 +8,7 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -90,7 +90,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index fae54da049a..c1d6fb52024 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -11,7 +11,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -252,7 +252,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): actor_net = MLP(**actor_net_kwargs) qvalue_module = QValueActor( module=actor_net, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["observation"], ) qvalue_module = qvalue_module.to(device) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index c5a1b88eea3..b07ae880046 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -220,6 +220,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 1b038d69d15..cebc3685625 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -205,6 +205,10 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 9cca9fd8af5..b892462339c 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -131,6 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not test_env.is_closed: + test_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index da2241ce9fa..184c850b626 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -145,6 +145,8 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not test_env.is_closed: + test_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 386f743c7d3..a9a08827f5d 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -222,6 +222,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index ddffffc2a8e..8051f07fe95 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -12,7 +12,7 @@ from torch import nn, optim from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) @@ -203,7 +203,7 @@ def make_sac_agent(cfg, train_env, eval_env, device): out_keys=["logits"], ) actor = ProbabilisticActor( - spec=CompositeSpec(action=eval_env.action_spec), + spec=Composite(action=eval_env.action_spec), module=actor_module, in_keys=["logits"], out_keys=["action"], diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 906273ee2f5..5d0162080e2 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -228,6 +228,9 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() + end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 173f88f7028..8149c700958 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -207,6 +207,8 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 3dbbfe87af4..6f39e824c60 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -5,7 +5,7 @@ import torch.nn import torch.optim -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -84,7 +84,7 @@ def make_dqn_modules_pixels(proof_environment): ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["pixels"], ) return qvalue_module diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index 2df280a04b4..c7f7491ad15 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -5,7 +5,7 @@ import torch.nn import torch.optim -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import RewardSum, StepCounter, TransformedEnv from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, QValueActor @@ -48,7 +48,7 @@ def make_dqn_modules(proof_environment): qvalue_module = QValueActor( module=mlp, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["observation"], ) return qvalue_module diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 6745b1a079a..849d8c813b6 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -20,11 +20,11 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs import ( @@ -92,8 +92,8 @@ def _make_env(cfg, device, from_pixels=False): else: raise NotImplementedError(f"Unknown lib {lib}.") default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)), - "belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)), + "state": Unbounded(shape=(cfg.networks.state_dim,)), + "belief": Unbounded(shape=(cfg.networks.rssm_hidden_dim,)), } env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -469,13 +469,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), @@ -488,7 +488,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.action_spec}), ), ) return actor_simulator @@ -526,12 +526,12 @@ def _dreamer_make_actor_real( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, ), } @@ -543,9 +543,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec( - **{action_key: proof_environment.action_spec.to("cpu")} - ), + spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), ), ), SafeModule( diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index b365dca3867..9fa3d6b399f 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -6,7 +6,7 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d1a16fd8192..53581782d20 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -141,6 +141,10 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index d50ff806294..3cdff06ffa2 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -204,6 +204,12 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time + + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 61d31b88eb8..a24c6168375 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -11,7 +11,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -306,7 +306,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): out_keys=["logits"], ) actor = ProbabilisticActor( - spec=CompositeSpec(action=eval_env.action_spec), + spec=Composite(action=eval_env.action_spec), module=actor_module, in_keys=["logits"], out_keys=["action"], diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index a4d2b88a9d0..39750c5d425 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -225,6 +225,12 @@ def train(cfg: "DictConfig"): # noqa: F821 logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() + if __name__ == "__main__": train() diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index e9de2ac4e14..aad1df14fff 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -251,6 +251,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index fa006a7d4a2..d2e218b843a 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -254,6 +254,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index 4e6a962c556..c5993f902c6 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -259,6 +259,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index f7b2523010b..cfafdd47c96 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -318,6 +318,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 2b02254032a..6d8883393d5 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -243,6 +243,9 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() + end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 219ae1b59b6..8cfea74d0bc 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -235,6 +235,9 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 2344da518bc..50f91ed49cd 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -6,8 +6,8 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec -from torchrl.data.tensor_specs import DiscreteBox +from torchrl.data import Composite +from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -92,7 +92,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): + if isinstance(proof_environment.action_spec.space, CategoricalBox): num_outputs = proof_environment.action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 7986738f8e6..a05d205b000 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -7,7 +7,7 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index 9904fe072ab..68860500149 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -215,6 +215,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 632ee58503d..01a59686ac9 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -213,6 +213,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index b3e8ed3b880..930ff509488 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -138,6 +138,8 @@ def main(cfg: "DictConfig"): # noqa: F821 if logger is not None: log_metrics(logger, to_log, i) + if not eval_env.is_closed: + eval_env.close() pbar.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ea4327bb460..795fda399de 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -11,15 +11,15 @@ from tensordict.utils import expand_right, NestedKey from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - NonTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiOneHot, + NonTensor, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.data.utils import consolidate_spec from torchrl.envs.common import EnvBase @@ -27,27 +27,27 @@ from torchrl.envs.utils import _terminated_or_truncated spec_dict = { - "bounded": BoundedTensorSpec, - "one_hot": OneHotDiscreteTensorSpec, - "categorical": DiscreteTensorSpec, - "unbounded": UnboundedContinuousTensorSpec, - "binary": BinaryDiscreteTensorSpec, - "mult_one_hot": MultiOneHotDiscreteTensorSpec, - "composite": CompositeSpec, + "bounded": Bounded, + "one_hot": OneHot, + "categorical": Categorical, + "unbounded": Unbounded, + "binary": Binary, + "mult_one_hot": MultiOneHot, + "composite": Composite, } default_spec_kwargs = { - OneHotDiscreteTensorSpec: {"n": 7}, - DiscreteTensorSpec: {"n": 7}, - BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, - UnboundedContinuousTensorSpec: { + OneHot: {"n": 7}, + Categorical: {"n": 7}, + Bounded: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, + Unbounded: { "shape": [ 7, ] }, - BinaryDiscreteTensorSpec: {"n": 7}, - MultiOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]}, - CompositeSpec: {}, + Binary: {"n": 7}, + MultiOneHot: {"nvec": [7, 3, 5]}, + Composite: {}, } @@ -68,8 +68,8 @@ def __new__( torch.get_default_dtype() ) reward_spec = cls._output_spec["full_reward_spec"] - if isinstance(reward_spec, CompositeSpec): - reward_spec = CompositeSpec( + if isinstance(reward_spec, Composite): + reward_spec = Composite( { key: item.to(torch.get_default_dtype()) for key, item in reward_spec.items(True, True) @@ -80,19 +80,19 @@ def __new__( else: reward_spec = reward_spec.to(torch.get_default_dtype()) cls._output_spec["full_reward_spec"] = reward_spec - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=cls._output_spec["full_reward_spec"].shape[:-1], ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"].clone(), terminated=cls._output_spec["full_done_spec"].clone(), shape=cls._output_spec["full_done_spec"].shape[:-1], ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=cls._input_spec["full_action_spec"].shape[:-1], ) @@ -156,15 +156,15 @@ def __new__( ): batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec( + action_spec = Unbounded( ( *batch_size, 1, ) ) if observation_spec is None: - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + observation_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -173,35 +173,35 @@ def __new__( shape=batch_size, ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( ( *batch_size, 1, ) ) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: - state_spec = CompositeSpec(shape=batch_size) - input_spec = CompositeSpec( + state_spec = Composite(shape=batch_size) + input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = input_spec - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(*args, **kwargs) @@ -268,15 +268,15 @@ def __new__( ): batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec( + action_spec = Unbounded( ( *batch_size, 1, ) ) if state_spec is None: - state_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + state_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -285,8 +285,8 @@ def __new__( shape=batch_size, ) if observation_spec is None: - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + observation_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -295,33 +295,33 @@ def __new__( shape=batch_size, ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( ( *batch_size, 1, ) ) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) - cls._output_spec = CompositeSpec(shape=batch_size) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, ) - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(cls, *args, **kwargs) @@ -442,46 +442,38 @@ def __new__( size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), - observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), + observation_spec = Composite( + observation=Unbounded(shape=torch.Size([*batch_size, size])), + observation_orig=Unbounded(shape=torch.Size([*batch_size, size])), shape=batch_size, ) if action_spec is None: if categorical_action_encoding: - action_spec_cls = DiscreteTensorSpec + action_spec_cls = Categorical action_spec = action_spec_cls(n=7, shape=batch_size) else: - action_spec_cls = OneHotDiscreteTensorSpec + action_spec_cls = OneHot action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(shape=(1,)) - ) + reward_spec = Composite(reward=Unbounded(shape=(1,))) if done_spec is None: - done_spec = CompositeSpec( - terminated=DiscreteTensorSpec( - 2, dtype=torch.bool, shape=(*batch_size, 1) - ) + done_spec = Composite( + terminated=Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) ) if state_spec is None: cls._out_key = "observation_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["observation"], }, shape=batch_size, ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, @@ -553,17 +545,13 @@ def __new__( size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), - observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), + observation_spec = Composite( + observation=Unbounded(shape=torch.Size([*batch_size, size])), + observation_orig=Unbounded(shape=torch.Size([*batch_size, size])), shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec( + action_spec = Bounded( -1, 1, ( @@ -572,23 +560,23 @@ def __new__( ), ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "observation_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["observation"], }, shape=batch_size, ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, @@ -681,25 +669,21 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 1, 7, 7]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 1, 7, 7]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])), shape=batch_size, ) if action_spec is None: - action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7)) + action_spec = OneHot(7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["pixels_orig"].clone(), }, @@ -741,25 +725,17 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), shape=batch_size, ) if action_spec is None: - action_spec_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_spec_cls = Categorical if categorical_action_encoding else OneHot action_spec = action_spec_cls(7, shape=(*batch_size, 7)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["pixels_orig"], }, @@ -808,25 +784,21 @@ def __new__( pixel_shape = [1, 7, 7] if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, *pixel_shape]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, *pixel_shape]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])), shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]]) + action_spec = Bounded(-1, 1, [*batch_size, pixel_shape[-1]]) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( {cls._out_key: observation_spec["pixels"]}, shape=batch_size ) return super().__new__( @@ -865,13 +837,9 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), ) return super().__new__( *args, @@ -928,8 +896,8 @@ def __init__( device=device, batch_size=batch_size, ) - self.observation_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + hidden_observation=Unbounded( ( *self.batch_size, 4, @@ -937,8 +905,8 @@ def __init__( ), shape=self.batch_size, ) - self.state_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec( + self.state_spec = Composite( + hidden_observation=Unbounded( ( *self.batch_size, 4, @@ -946,13 +914,13 @@ def __init__( ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec( + self.action_spec = Unbounded( ( *self.batch_size, 1, ) ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, @@ -1012,8 +980,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): self.max_steps = max_steps self.start_val = start_val - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( ( *self.batch_size, 1, @@ -1024,14 +992,14 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): shape=self.batch_size, device=self.device, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, ), device=self.device, ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( 2, dtype=torch.bool, shape=( @@ -1040,9 +1008,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): ), device=self.device, ) - self.action_spec = BinaryDiscreteTensorSpec( - n=1, shape=[*self.batch_size, 1], device=self.device - ) + self.action_spec = Binary(n=1, shape=[*self.batch_size, 1], device=self.device) self.register_buffer( "count", torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int), @@ -1129,9 +1095,9 @@ def __init__( self.nested_reward = nest_reward if self.nested_obs_action: - self.observation_spec = CompositeSpec( + self.observation_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "states": self.observation_spec["observation"] .unsqueeze(-1) @@ -1145,9 +1111,9 @@ def __init__( }, shape=self.batch_size, ) - self.action_spec = CompositeSpec( + self.action_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "action": self.action_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim, 1 @@ -1163,9 +1129,9 @@ def __init__( ) if self.nested_reward: - self.reward_spec = CompositeSpec( + self.reward_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "reward": self.reward_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim, 1 @@ -1184,12 +1150,12 @@ def __init__( done_spec = self.full_done_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim ) - done_spec = CompositeSpec( + done_spec = Composite( {"data": done_spec}, shape=self.batch_size, ) if self.has_root_done: - done_spec["done"] = DiscreteTensorSpec( + done_spec["done"] = Categorical( 2, shape=( *self.batch_size, @@ -1309,8 +1275,8 @@ def __init__( self.max_steps = max_steps - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( ( *self.batch_size, 1, @@ -1319,13 +1285,13 @@ def __init__( ), shape=self.batch_size, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, ) ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( 2, dtype=torch.bool, shape=( @@ -1333,7 +1299,7 @@ def __init__( 1, ), ) - self.action_spec = BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]) + self.action_spec = Binary(n=1, shape=[*self.batch_size, 1]) self.count = torch.zeros( (*self.batch_size, 1), device=self.device, dtype=torch.int @@ -1419,34 +1385,30 @@ def _make_specs(self): obs_spec_unlazy = consolidate_spec(obs_specs) action_specs = torch.stack(action_specs, dim=0) - self.unbatched_observation_spec = CompositeSpec( + self.unbatched_observation_spec = Composite( lazy=obs_spec_unlazy, - state=UnboundedContinuousTensorSpec(shape=(64, 64, 3)), + state=Unbounded(shape=(64, 64, 3)), device=self.device, ) - self.unbatched_action_spec = CompositeSpec( + self.unbatched_action_spec = Composite( lazy=action_specs, device=self.device, ) - self.unbatched_reward_spec = CompositeSpec( + self.unbatched_reward_spec = Composite( { - "lazy": CompositeSpec( - { - "reward": UnboundedContinuousTensorSpec( - shape=(self.n_nested_dim, 1) - ) - }, + "lazy": Composite( + {"reward": Unbounded(shape=(self.n_nested_dim, 1))}, shape=(self.n_nested_dim,), ) }, device=self.device, ) - self.unbatched_done_spec = CompositeSpec( + self.unbatched_done_spec = Composite( { - "lazy": CompositeSpec( + "lazy": Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=(self.n_nested_dim, 1), dtype=torch.bool, @@ -1472,17 +1434,17 @@ def _make_specs(self): ) def get_agent_obs_spec(self, i): - camera = BoundedTensorSpec(low=0, high=200, shape=(7, 7, 3)) - vector_3d = UnboundedContinuousTensorSpec(shape=(3,)) - vector_2d = UnboundedContinuousTensorSpec(shape=(2,)) - lidar = BoundedTensorSpec(low=0, high=5, shape=(8,)) + camera = Bounded(low=0, high=200, shape=(7, 7, 3)) + vector_3d = Unbounded(shape=(3,)) + vector_2d = Unbounded(shape=(2,)) + lidar = Bounded(low=0, high=5, shape=(8,)) - tensor_0 = UnboundedContinuousTensorSpec(shape=(1,)) - tensor_1 = BoundedTensorSpec(low=0, high=3, shape=(1, 2)) - tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3)) + tensor_0 = Unbounded(shape=(1,)) + tensor_1 = Bounded(low=0, high=3, shape=(1, 2)) + tensor_2 = Unbounded(shape=(1, 2, 3)) if i == 0: - return CompositeSpec( + return Composite( { "camera": camera, "lidar": lidar, @@ -1492,7 +1454,7 @@ def get_agent_obs_spec(self, i): device=self.device, ) elif i == 1: - return CompositeSpec( + return Composite( { "camera": camera, "lidar": lidar, @@ -1502,7 +1464,7 @@ def get_agent_obs_spec(self, i): device=self.device, ) elif i == 2: - return CompositeSpec( + return Composite( { "camera": camera, "vector": vector_2d, @@ -1514,8 +1476,8 @@ def get_agent_obs_spec(self, i): raise ValueError(f"Index {i} undefined for index 3") def get_agent_action_spec(self, i): - action_3d = BoundedTensorSpec(low=-1, high=1, shape=(3,)) - action_2d = BoundedTensorSpec(low=-1, high=1, shape=(2,)) + action_3d = Bounded(low=-1, high=1, shape=(3,)) + action_2d = Bounded(low=-1, high=1, shape=(2,)) # Some have 2d action and some 3d # TODO Introduce composite heterogeneous actions @@ -1528,7 +1490,7 @@ def get_agent_action_spec(self, i): else: raise ValueError(f"Index {i} undefined for index 3") - return CompositeSpec({"action": ret}) + return Composite({"action": ret}) def _reset( self, @@ -1659,18 +1621,16 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): ) def make_specs(self): - self.unbatched_observation_spec = CompositeSpec( - nested_1=CompositeSpec( - observation=BoundedTensorSpec( - low=0, high=200, shape=(self.nested_dim_1, 3) - ), + self.unbatched_observation_spec = Composite( + nested_1=Composite( + observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 2)), + nested_2=Composite( + observation=Unbounded(shape=(self.nested_dim_2, 2)), shape=(self.nested_dim_2,), ), - observation=UnboundedContinuousTensorSpec( + observation=Unbounded( shape=( 10, 10, @@ -1679,51 +1639,51 @@ def make_specs(self): ), ) - self.unbatched_action_spec = CompositeSpec( - nested_1=CompositeSpec( - action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)), + self.unbatched_action_spec = Composite( + nested_1=Composite( + action=Categorical(n=2, shape=(self.nested_dim_1,)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - azione=BoundedTensorSpec(low=0, high=100, shape=(self.nested_dim_2, 1)), + nested_2=Composite( + azione=Bounded(low=0, high=100, shape=(self.nested_dim_2, 1)), shape=(self.nested_dim_2,), ), - action=OneHotDiscreteTensorSpec(n=2), + action=OneHot(n=2), ) - self.unbatched_reward_spec = CompositeSpec( - nested_1=CompositeSpec( - gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)), + self.unbatched_reward_spec = Composite( + nested_1=Composite( + gift=Unbounded(shape=(self.nested_dim_1, 1)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)), + nested_2=Composite( + reward=Unbounded(shape=(self.nested_dim_2, 1)), shape=(self.nested_dim_2,), ), - reward=UnboundedContinuousTensorSpec(shape=(1,)), + reward=Unbounded(shape=(1,)), ) - self.unbatched_done_spec = CompositeSpec( - nested_1=CompositeSpec( - done=DiscreteTensorSpec( + self.unbatched_done_spec = Composite( + nested_1=Composite( + done=Categorical( n=2, shape=(self.nested_dim_1, 1), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(self.nested_dim_1, 1), dtype=torch.bool, ), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - done=DiscreteTensorSpec( + nested_2=Composite( + done=Categorical( n=2, shape=(self.nested_dim_2, 1), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(self.nested_dim_2, 1), dtype=torch.bool, @@ -1731,12 +1691,12 @@ def make_specs(self): shape=(self.nested_dim_2,), ), # done at the root always prevail - done=DiscreteTensorSpec( + done=Categorical( n=2, shape=(1,), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(1,), dtype=torch.bool, @@ -1829,15 +1789,15 @@ def _set_seed(self, seed: Optional[int]): class EnvWithMetadata(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - tensor=UnboundedContinuousTensorSpec(3), - non_tensor=NonTensorSpec(shape=()), + self.observation_spec = Composite( + tensor=Unbounded(3), + non_tensor=NonTensor(shape=()), ) - self.state_spec = CompositeSpec( - non_tensor=NonTensorSpec(shape=()), + self.state_spec = Composite( + non_tensor=NonTensor(shape=()), ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(1) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(1) def _reset(self, tensordict): data = self.observation_spec.zero() @@ -1935,16 +1895,16 @@ def _reset(self, tensordict=None): class EnvWithDynamicSpec(EnvBase): def __init__(self, max_count=5): super().__init__(batch_size=()) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)), + self.observation_spec = Composite( + observation=Unbounded(shape=(3, -1, 2)), ) - self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,)) - self.full_done_spec = CompositeSpec( - done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), + self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + self.full_done_spec = Composite( + done=Binary(1, shape=(1,), dtype=torch.bool), + terminated=Binary(1, shape=(1,), dtype=torch.bool), + truncated=Binary(1, shape=(1,), dtype=torch.bool), ) - self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float) + self.reward_spec = Unbounded((1,), dtype=torch.float) self.count = 0 self.max_count = max_count diff --git a/test/test_actors.py b/test/test_actors.py index 2d160e31bba..439094e922a 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -14,14 +14,7 @@ from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn -from torchrl.data import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, -) +from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot from torchrl.data.rlhf.dataset import _has_transformers from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal from torchrl.modules.tensordict_module.actors import ( @@ -50,9 +43,7 @@ ) def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 - ) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) policy_module = TensorDictModule( nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] ) @@ -111,9 +102,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= ) def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 - ) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) actor_net = nn.Sequential( nn.Linear(1, 2), NormalParamExtractor(), @@ -181,7 +170,7 @@ def test_distributional_qvalue_hook_wrong_action_space(self): DistributionalQValueHook(action_space="wrong_value", support=None) def test_distributional_qvalue_hook_conflicting_spec(self): - spec = OneHotDiscreteTensorSpec(3) + spec = OneHot(3) _process_action_space_spec("one-hot", spec) _process_action_space_spec("one_hot", spec) _process_action_space_spec("one_hot", None) @@ -190,19 +179,19 @@ def test_distributional_qvalue_hook_conflicting_spec(self): ValueError, match="The action spec and the action space do not match" ): _process_action_space_spec("multi-one-hot", spec) - spec = MultiOneHotDiscreteTensorSpec([3, 3]) + spec = MultiOneHot([3, 3]) _process_action_space_spec("multi-one-hot", spec) _process_action_space_spec(spec, spec) with pytest.raises( ValueError, match="Passing an action_space as a TensorSpec and a spec" ): - _process_action_space_spec(OneHotDiscreteTensorSpec(3), spec) + _process_action_space_spec(OneHot(3), spec) with pytest.raises( - ValueError, match="action_space cannot be of type CompositeSpec" + ValueError, match="action_space cannot be of type Composite" ): - _process_action_space_spec(CompositeSpec(), spec) + _process_action_space_spec(Composite(), spec) with pytest.raises(KeyError, match="action could not be found in the spec"): - _process_action_space_spec(None, CompositeSpec()) + _process_action_space_spec(None, Composite()) with pytest.raises( ValueError, match="Neither action_space nor spec was defined" ): @@ -248,10 +237,10 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5): ValueError, match="Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match.", ): - _process_action_space_spec(BinaryDiscreteTensorSpec(n=1), action_spec) - _process_action_space_spec(BinaryDiscreteTensorSpec(n=1), leaf_action_spec) + _process_action_space_spec(Binary(n=1), action_spec) + _process_action_space_spec(Binary(n=1), leaf_action_spec) with pytest.raises( - ValueError, match="action_space cannot be of type CompositeSpec" + ValueError, match="action_space cannot be of type Composite" ): _process_action_space_spec(action_spec, None) @@ -652,7 +641,7 @@ def test_value_based_policy(device): torch.manual_seed(0) obs_dim = 4 action_dim = 5 - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) def make_net(): net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device) @@ -681,9 +670,7 @@ def make_net(): assert (action.sum(-1) == 1).all() -@pytest.mark.parametrize( - "spec", [None, OneHotDiscreteTensorSpec(3), MultiOneHotDiscreteTensorSpec([3, 2])] -) +@pytest.mark.parametrize("spec", [None, OneHot(3), MultiOneHot([3, 2])]) @pytest.mark.parametrize( "action_space", [None, "one-hot", "one_hot", "mult-one-hot", "mult_one_hot"] ) @@ -706,12 +693,9 @@ def test_qvalactor_construct( QValueActor(**kwargs) return if ( - type(spec) is MultiOneHotDiscreteTensorSpec + type(spec) is MultiOneHot and action_space not in ("mult-one-hot", "mult_one_hot", None) - ) or ( - type(spec) is OneHotDiscreteTensorSpec - and action_space not in ("one-hot", "one_hot", None) - ): + ) or (type(spec) is OneHot and action_space not in ("one-hot", "one_hot", None)): with pytest.raises( ValueError, match="The action spec and the action space do not match" ): @@ -725,7 +709,7 @@ def test_value_based_policy_categorical(device): torch.manual_seed(0) obs_dim = 4 action_dim = 5 - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) def make_net(): net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device) diff --git a/test/test_collector.py b/test/test_collector.py index 7d7208aead0..9b0117e7486 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -68,12 +68,12 @@ ) from torchrl.collectors.utils import split_trajectories from torchrl.data import ( - CompositeSpec, + Composite, LazyTensorStorage, - NonTensorSpec, + NonTensor, ReplayBuffer, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs import ( EnvBase, @@ -210,22 +210,16 @@ class DeviceLessEnv(EnvBase): def __init__(self, default_device): self.default_device = default_device super().__init__(device=None) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((), device=default_device) + self.observation_spec = Composite( + observation=Unbounded((), device=default_device) ) - self.reward_spec = UnboundedContinuousTensorSpec(1, device=default_device) - self.full_done_spec = CompositeSpec( - done=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - truncated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - terminated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), + self.reward_spec = Unbounded(1, device=default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), ) - self.action_spec = UnboundedContinuousTensorSpec((), device=None) + self.action_spec = Unbounded((), device=None) assert self.device is None assert self.full_observation_spec is not None assert self.full_done_spec is not None @@ -268,29 +262,17 @@ class EnvWithDevice(EnvBase): def __init__(self, default_device): self.default_device = default_device super().__init__(device=self.default_device) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - (), device=self.default_device - ) - ) - self.reward_spec = UnboundedContinuousTensorSpec( - 1, device=self.default_device + self.observation_spec = Composite( + observation=Unbounded((), device=self.default_device) ) - self.full_done_spec = CompositeSpec( - done=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - truncated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - terminated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), + self.reward_spec = Unbounded(1, device=self.default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), device=self.default_device, ) - self.action_spec = UnboundedContinuousTensorSpec( - (), device=self.default_device - ) + self.action_spec = Unbounded((), device=self.default_device) assert self.device == _make_ordinal_device( torch.device(self.default_device) ) @@ -1295,7 +1277,7 @@ def make_env(): policy, copier, OrnsteinUhlenbeckProcessModule( - spec=CompositeSpec({key: None for key in policy.out_keys}) + spec=Composite({key: None for key in policy.out_keys}) ), ) @@ -1368,12 +1350,12 @@ def test_collector_output_keys( ], } if explicit_spec: - hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size)) - policy_kwargs["spec"] = CompositeSpec( - action=UnboundedContinuousTensorSpec(), + hidden_spec = Unbounded((1, hidden_size)) + policy_kwargs["spec"] = Composite( + action=Unbounded(), hidden1=hidden_spec, hidden2=hidden_spec, - next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec), + next=Composite(hidden1=hidden_spec, hidden2=hidden_spec), ) policy = SafeModule(**policy_kwargs) @@ -2170,15 +2152,9 @@ class DummyEnv(EnvBase): def __init__(self, device, batch_size=[]): # noqa: B006 super().__init__(batch_size=batch_size, device=device) self.state = torch.zeros(self.batch_size, device=device) - self.observation_spec = CompositeSpec( - state=UnboundedContinuousTensorSpec(shape=(), device=device) - ) - self.action_spec = UnboundedContinuousTensorSpec( - shape=batch_size, device=device - ) - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(*batch_size, 1), device=device - ) + self.observation_spec = Composite(state=Unbounded(shape=(), device=device)) + self.action_spec = Unbounded(shape=batch_size, device=device) + self.reward_spec = Unbounded(shape=(*batch_size, 1), device=device) def _step( self, @@ -2685,7 +2661,7 @@ def _reset( def transform_observation_spec( self, observation_spec: TensorSpec ) -> TensorSpec: - observation_spec["nt"] = NonTensorSpec(shape=()) + observation_spec["nt"] = NonTensor(shape=()) return observation_spec @classmethod diff --git a/test/test_cost.py b/test/test_cost.py index 6192e45c113..30ccb2e153b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -54,14 +54,7 @@ from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv @@ -304,9 +297,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) # elif action_spec_type == "nd_bounded": # action_spec = BoundedTensorSpec( # -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) @@ -318,7 +311,7 @@ def _create_mock_actor( if is_nn_module: return module.to(device) actor = QValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, ( @@ -349,14 +342,12 @@ def _create_mock_distributional_actor( # Actor var_nums = None if action_spec_type == "mult_one_hot": - action_spec = MultiOneHotDiscreteTensorSpec( - [action_dim // 2, action_dim // 2] - ) + action_spec = MultiOneHot([action_dim // 2, action_dim // 2]) var_nums = action_spec.nvec elif action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong {action_spec_type}") support = torch.linspace(vmin, vmax, atoms, dtype=torch.float) @@ -367,7 +358,7 @@ def _create_mock_distributional_actor( # if is_nn_module: # return module actor = DistributionalQValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, action_value_key: None, @@ -776,7 +767,7 @@ def test_dqn_notensordict( ): n_obs = 3 n_action = 4 - action_spec = OneHotDiscreteTensorSpec(n_action) + action_spec = OneHot(n_action) module = nn.Linear(n_obs, n_action) # a simple value model actor = QValueActor( spec=action_spec, @@ -937,9 +928,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong {action_spec_type}") @@ -1386,7 +1377,7 @@ class TestDDPG(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Linear(obs_dim, action_dim) @@ -2024,7 +2015,7 @@ def _create_mock_actor( dropout=0.0, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Sequential( @@ -2376,7 +2367,7 @@ def test_td3_separate_losses( loss_fn = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), loss_function="l2", separate_losses=separate_losses, ) @@ -2730,7 +2721,7 @@ def _create_mock_actor( dropout=0.0, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Sequential( @@ -3089,7 +3080,7 @@ def test_td3bc_separate_losses( loss_fn = TD3BCLoss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), loss_function="l2", separate_losses=separate_losses, ) @@ -3456,7 +3447,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -3883,7 +3874,7 @@ def test_sac_separate_losses( loss_fn = SACLoss( actor_network=actor, qvalue_network=qvalue, - action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + action_spec=Unbounded(shape=(n_act,)), num_qvalue_nets=1, separate_losses=separate_losses, ) @@ -4287,14 +4278,14 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) state = loss.state_dict() loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -4302,7 +4293,7 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.target_entropy state = loss.state_dict() @@ -4310,7 +4301,7 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -4368,7 +4359,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( @@ -4954,7 +4945,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -5270,7 +5261,7 @@ def test_crossq_separate_losses( loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + action_spec=Unbounded(shape=(n_act,)), num_qvalue_nets=1, separate_losses=separate_losses, ) @@ -5575,14 +5566,14 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) state = loss.state_dict() loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -5590,7 +5581,7 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.target_entropy state = loss.state_dict() @@ -5598,7 +5589,7 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -5649,7 +5640,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -6594,7 +6585,7 @@ class TestCQL(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -7157,9 +7148,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong action spec type: {action_spec_type}") @@ -7167,7 +7158,7 @@ def _create_mock_actor( if is_nn_module: return module.to(device) actor = QValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, ( @@ -7477,7 +7468,7 @@ def test_dcql_notensordict( ): n_obs = 3 n_action = 4 - action_spec = OneHotDiscreteTensorSpec(n_action) + action_spec = OneHot(n_action) module = nn.Linear(n_obs, n_action) # a simple value model actor = QValueActor( spec=action_spec, @@ -7552,7 +7543,7 @@ def _create_mock_actor( sample_log_prob_key="sample_log_prob", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -7588,7 +7579,7 @@ def _create_mock_value( def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) base_layer = nn.Linear(obs_dim, 5) @@ -7616,7 +7607,7 @@ def _create_mock_actor_value_shared( self, batch=2, obs_dim=3, action_dim=4, device="cpu" ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) base_layer = nn.Linear(obs_dim, 5) @@ -8443,7 +8434,7 @@ def _create_mock_actor( sample_log_prob_key="sample_log_prob", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -9152,7 +9143,7 @@ def test_reinforce_value_net( distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) if advantage == "gae": advantage = GAE( @@ -9262,7 +9253,7 @@ def test_reinforce_tensordict_keys(self, td_est): distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) loss_fn = ReinforceLoss( @@ -9456,7 +9447,7 @@ def test_reinforce_notensordict( distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net) loss.set_keys( @@ -9632,8 +9623,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13 ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -9709,8 +9700,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=13): ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -9760,8 +9751,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13): ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -10050,7 +10041,7 @@ class TestOnlineDT(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -10282,7 +10273,7 @@ class TestDT(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -10696,7 +10687,7 @@ def _create_mock_actor( observation_key="observation", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -11507,7 +11498,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( diff --git a/test/test_env.py b/test/test_env.py index dee03c06e7d..b945498573d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -66,12 +66,7 @@ from torch import nn from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - NonTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, NonTensor, Unbounded from torchrl.envs import ( CatFrames, CatTensors, @@ -1908,18 +1903,12 @@ def test_info_dict_reader(self, device, seed=0): env.set_info_dict_reader( default_info_dict_reader( ["x_position"], - spec=CompositeSpec( - x_position=UnboundedContinuousTensorSpec( - dtype=torch.float64, shape=() - ) - ), + spec=Composite(x_position=Unbounded(dtype=torch.float64, shape=())), ) ) assert "x_position" in env.observation_spec.keys() - assert isinstance( - env.observation_spec["x_position"], UnboundedContinuousTensorSpec - ) + assert isinstance(env.observation_spec["x_position"], Unbounded) tensordict = env.reset() tensordict = env.rand_step(tensordict) @@ -1932,13 +1921,13 @@ def test_info_dict_reader(self, device, seed=0): ) for spec in ( - {"x_position": UnboundedContinuousTensorSpec((), dtype=torch.float64)}, + {"x_position": Unbounded((), dtype=torch.float64)}, # None, - CompositeSpec( - x_position=UnboundedContinuousTensorSpec((), dtype=torch.float64), + Composite( + x_position=Unbounded((), dtype=torch.float64), shape=[], ), - [UnboundedContinuousTensorSpec((), dtype=torch.float64)], + [Unbounded((), dtype=torch.float64)], ): env2 = GymWrapper(gym.make("HalfCheetah-v4")) env2.set_info_dict_reader( @@ -2079,7 +2068,7 @@ def main_penv(j, q=None): ], ) spec = env_p.action_spec - policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec.to(device))) + policy = TestConcurrentEnvs.Policy(Composite(action=spec.to(device))) N = 10 r_p = [] r_s = [] @@ -2113,7 +2102,7 @@ def main_collector(j, q=None): lambda i=i: CountingEnv(i, device=device) for i in range(j, j + n_workers) ] spec = make_envs[0]().action_spec - policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec)) + policy = TestConcurrentEnvs.Policy(Composite(action=spec)) collector = MultiSyncDataCollector( make_envs, policy, @@ -2225,7 +2214,7 @@ def test_nested_env(self, envclass): else: raise NotImplementedError reset = env.reset() - assert not isinstance(env.reward_spec, CompositeSpec) + assert not isinstance(env.reward_spec, Composite) for done_key in env.done_keys: assert ( env.full_done_spec[done_key] @@ -2496,8 +2485,8 @@ def test_mocking_envs(envclass): class TestTerminatedOrTruncated: @pytest.mark.parametrize("done_key", ["done", "terminated", "truncated"]) def test_root_prevail(self, done_key): - _spec = DiscreteTensorSpec(2, shape=(), dtype=torch.bool) - spec = CompositeSpec({done_key: _spec, ("agent", done_key): _spec}) + _spec = Categorical(2, shape=(), dtype=torch.bool) + spec = Composite({done_key: _spec, ("agent", done_key): _spec}) data = TensorDict({done_key: [False], ("agent", done_key): [True, False]}, []) assert not _terminated_or_truncated(data) assert not _terminated_or_truncated(data, full_done_spec=spec) @@ -2560,8 +2549,8 @@ def test_terminated_or_truncated_nospec(self): def test_terminated_or_truncated_spec(self): done_shape = (2, 1) nested_done_shape = (2, 3, 1) - spec = CompositeSpec( - done=DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), + spec = Composite( + done=Categorical(2, shape=done_shape, dtype=torch.bool), shape=[ 2, ], @@ -2578,12 +2567,12 @@ def test_terminated_or_truncated_spec(self): ) assert data.get("_reset", None) is None - spec = CompositeSpec( + spec = Composite( { - ("agent", "done"): DiscreteTensorSpec( + ("agent", "done"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), - ("nested", "done"): DiscreteTensorSpec( + ("nested", "done"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), }, @@ -2618,11 +2607,11 @@ def test_terminated_or_truncated_spec(self): assert data["agent", "_reset"].shape == nested_done_shape assert data["nested", "_reset"].shape == nested_done_shape - spec = CompositeSpec( + spec = Composite( { - "truncated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), - "terminated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), - ("nested", "terminated"): DiscreteTensorSpec( + "truncated": Categorical(2, shape=done_shape, dtype=torch.bool), + "terminated": Categorical(2, shape=done_shape, dtype=torch.bool), + ("nested", "terminated"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), }, @@ -2774,15 +2763,15 @@ def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td): class DifferentiableEnv(EnvBase): def __init__(self, device): super().__init__(device=device) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(3, device=device), + self.observation_spec = Composite( + observation=Unbounded(3, device=device), device=device, ) - self.action_spec = CompositeSpec( - action=UnboundedContinuousTensorSpec(3, device=device), device=device + self.action_spec = Composite( + action=Unbounded(3, device=device), device=device ) - self.reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(1, device=device), device=device + self.reward_spec = Composite( + reward=Unbounded(1, device=device), device=device ) self.seed = 0 @@ -3283,7 +3272,7 @@ def _reset( return tensordict_reset def transform_observation_spec(self, observation_spec): - observation_spec["string"] = NonTensorSpec(()) + observation_spec["string"] = NonTensor(()) return observation_spec @pytest.mark.parametrize("batched", ["serial", "parallel"]) diff --git a/test/test_exploration.py b/test/test_exploration.py index b2fd97d986f..3bb05708d83 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -21,12 +21,7 @@ from torchrl._utils import _replace_last from torchrl.collectors import SyncDataCollector -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, -) +from torchrl.data import Bounded, Categorical, Composite, OneHot from torchrl.envs import SerialEnv from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv from torchrl.envs.utils import set_exploration_type @@ -59,7 +54,7 @@ class TestEGreedy: @set_exploration_type(InteractionType.RANDOM) def test_egreedy(self, eps_init, module): torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) + spec = Bounded(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 4, bias=False) policy = Actor(spec=spec, module=module) @@ -91,9 +86,9 @@ def test_egreedy_masked(self, module, eps_init, spec_class): batch_size = (3, 4, 2) module = torch.nn.Linear(action_size, action_size, bias=False) if spec_class == "discrete": - spec = DiscreteTensorSpec(action_size) + spec = Categorical(action_size) else: - spec = OneHotDiscreteTensorSpec( + spec = OneHot( action_size, shape=(action_size,), ) @@ -166,7 +161,7 @@ def test_no_spec_error( action_size = 4 batch_size = (3, 4, 2) module = torch.nn.Linear(action_size, action_size, bias=False) - spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,)) + spec = OneHot(action_size, shape=(action_size,)) policy = QValueActor(spec=spec, module=module) explorative_policy = TensorDictSequential( policy, @@ -187,7 +182,7 @@ def test_no_spec_error( @pytest.mark.parametrize("module", [True, False]) def test_wrong_action_shape(self, module): torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) + spec = Bounded(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 5, bias=False) policy = Actor(spec=spec, module=module) @@ -240,7 +235,7 @@ def test_ou( device ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) + action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( spec=action_spec, module=module, @@ -444,7 +439,7 @@ def test_additivegaussian_sd( pytest.skip("module raises an error if given spec=None") torch.manual_seed(seed) - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -463,9 +458,7 @@ def test_additivegaussian_sd( spec=None, ) policy = ProbabilisticActor( - spec=CompositeSpec(action=action_spec) - if spec_origin is not None - else None, + spec=Composite(action=action_spec) if spec_origin is not None else None, module=module, in_keys=["loc", "scale"], distribution_class=TanhNormal, @@ -541,7 +534,7 @@ def test_additivegaussian( device ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -670,7 +663,7 @@ def test_gsde( module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"low": -bound, "high": bound} - spec = BoundedTensorSpec( + spec = Bounded( -torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,) ).to(device) diff --git a/test/test_helpers.py b/test/test_helpers.py index f468eddf6ed..cf28252a318 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -30,7 +30,7 @@ MockSerialEnv, ) from packaging import version -from torchrl.data import BoundedTensorSpec, CompositeSpec +from torchrl.data import Bounded, Composite from torchrl.envs.libs.gym import _has_gym from torchrl.envs.transforms import ObservationNorm from torchrl.envs.transforms.transforms import ( @@ -259,17 +259,14 @@ def test_transformed_env_constructor_with_state_dict(from_pixels): def test_initialize_stats_from_observation_norms(device, keys, composed, initialized): obs_spec, stat_key = None, None if keys: - obs_spec = CompositeSpec( - **{ - key: BoundedTensorSpec(high=1, low=1, shape=torch.Size([1])) - for key in keys - } + obs_spec = Composite( + **{key: Bounded(high=1, low=1, shape=torch.Size([1])) for key in keys} ) stat_key = keys[0] env = ContinuousActionVecMockEnv( device=device, observation_spec=obs_spec, - action_spec=BoundedTensorSpec(low=1, high=2, shape=torch.Size((1,))), + action_spec=Bounded(low=1, high=2, shape=torch.Size((1,))), ) env.out_key = "observation" else: diff --git a/test/test_libs.py b/test/test_libs.py index 6ccbf2788a9..a76cb610d69 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -55,16 +55,16 @@ from torchrl._utils import implement_for, logger as torchrl_logger from torchrl.collectors.collectors import SyncDataCollector from torchrl.data import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + OneHot, ReplayBuffer, ReplayBufferEnsemble, - UnboundedContinuousTensorSpec, + Unbounded, UnboundedDiscreteTensorSpec, ) from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay @@ -206,18 +206,16 @@ def __init__(self, arg1, *, arg2, **kwargs): assert arg1 == 1 assert arg2 == 2 - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - other=CompositeSpec( - another_other=UnboundedContinuousTensorSpec((*self.batch_size, 3)), + self.observation_spec = Composite( + observation=Unbounded((*self.batch_size, 3)), + other=Composite( + another_other=Unbounded((*self.batch_size, 3)), shape=self.batch_size, ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - self.done_spec = DiscreteTensorSpec( - 2, (*self.batch_size, 1), dtype=torch.bool - ) + self.action_spec = Unbounded((*self.batch_size, 3)) + self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone() def _reset(self, tensordict): @@ -242,16 +240,14 @@ def _set_seed(self, seed): @implement_for("gym", None, "0.18") def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), # g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) @@ -259,16 +255,14 @@ def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): def _make_spec( # noqa: F811 self, batch_size, cat, cat_shape, multicat, multicat_shape ): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) @@ -276,27 +270,23 @@ def _make_spec( # noqa: F811 def _make_spec( # noqa: F811 self, batch_size, cat, cat_shape, multicat, multicat_shape ): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) @pytest.mark.parametrize("categorical", [True, False]) def test_gym_spec_cast(self, categorical): batch_size = [3, 4] - cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec + cat = Categorical if categorical else OneHot cat_shape = batch_size if categorical else (*batch_size, 5) - multicat = ( - MultiDiscreteTensorSpec if categorical else MultiOneHotDiscreteTensorSpec - ) + multicat = MultiCategorical if categorical else MultiOneHot multicat_shape = 2 if categorical else 5 spec = self._make_spec(batch_size, cat, cat_shape, multicat, multicat_shape) recon = _gym_to_torchrl_spec_transform( diff --git a/test/test_modules.py b/test/test_modules.py index 00e58678788..8966b61154c 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -16,7 +16,7 @@ from packaging import version from tensordict import TensorDict from torch import nn -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec +from torchrl.data.tensor_specs import Bounded, Composite from torchrl.modules import ( CEMPlanner, DTActor, @@ -466,9 +466,7 @@ def test_dreamer_decoder( @pytest.mark.parametrize("deter_size", [20, 30]) @pytest.mark.parametrize("action_size", [3, 6]) def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): - action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, low=-1, high=1 - ) + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) rssm_prior = RSSMPrior( action_spec, hidden_dim=stoch_size, @@ -521,9 +519,7 @@ def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size): def test_rssm_rollout( self, device, batch_size, temporal_size, stoch_size, deter_size, action_size ): - action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, low=-1, high=1 - ) + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) rssm_prior = RSSMPrior( action_spec, hidden_dim=stoch_size, @@ -650,10 +646,10 @@ def test_errors(self): ): TanhModule(in_keys=["a", "b"], out_keys=["a"]) with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"): - spec = BoundedTensorSpec(-1, 1, shape=()) + spec = Bounded(-1, 1, shape=()) TanhModule(in_keys=["act"], low=-2, spec=spec) with pytest.raises(ValueError, match=r"The maximum value \(-2\) provided to"): - spec = BoundedTensorSpec(-1, 1, shape=()) + spec = Bounded(-1, 1, shape=()) TanhModule(in_keys=["act"], high=-2, spec=spec) with pytest.raises(ValueError, match="Got high < low"): TanhModule(in_keys=["act"], high=-2, low=-1) @@ -709,12 +705,12 @@ def test_multi_inputs(self, out_keys, has_spec): if any(has_spec): spec = {} if has_spec[0]: - spec.update({real_out_keys[0]: BoundedTensorSpec(-2.0, 2.0, shape=())}) + spec.update({real_out_keys[0]: Bounded(-2.0, 2.0, shape=())}) low, high = -2.0, 2.0 if has_spec[1]: - spec.update({real_out_keys[1]: BoundedTensorSpec(-3.0, 3.0, shape=())}) + spec.update({real_out_keys[1]: Bounded(-3.0, 3.0, shape=())}) low, high = None, None - spec = CompositeSpec(spec) + spec = Composite(spec) else: spec = None low, high = -2.0, 2.0 diff --git a/test/test_specs.py b/test/test_specs.py index 2d597d770f0..82d2b7f2e1d 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse import contextlib +import warnings import numpy as np import pytest @@ -14,19 +15,32 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple from torchrl._utils import _make_ordinal_device + from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, + Binary, BinaryDiscreteTensorSpec, + Bounded, BoundedTensorSpec, + Categorical, + Composite, CompositeSpec, + ContinuousBox, DiscreteTensorSpec, - LazyStackedCompositeSpec, + MultiCategorical, MultiDiscreteTensorSpec, + MultiOneHot, MultiOneHotDiscreteTensorSpec, + NonTensor, NonTensorSpec, + OneHot, OneHotDiscreteTensorSpec, + StackedComposite, TensorSpec, + Unbounded, + UnboundedContinuous, UnboundedContinuousTensorSpec, + UnboundedDiscrete, UnboundedDiscreteTensorSpec, ) from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec @@ -38,9 +52,7 @@ def test_bounded(dtype): np.random.seed(0) for _ in range(100): bounds = torch.randn(2).sort()[0] - ts = BoundedTensorSpec( - bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype - ) + ts = Bounded(bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype) _dtype = dtype if dtype is None: _dtype = torch.get_default_dtype() @@ -53,7 +65,7 @@ def test_bounded(dtype): assert (ts.encode(ts.to_numpy(r)) == r).all() -@pytest.mark.parametrize("cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]) +@pytest.mark.parametrize("cls", [OneHot, Categorical]) def test_discrete(cls): torch.manual_seed(0) np.random.seed(0) @@ -78,7 +90,7 @@ def test_discrete(cls): def test_unbounded(dtype): torch.manual_seed(0) np.random.seed(0) - ts = UnboundedContinuousTensorSpec(dtype=dtype) + ts = Unbounded(dtype=dtype) if dtype is None: dtype = torch.get_default_dtype() @@ -99,7 +111,7 @@ def test_ndbounded(dtype, shape): for _ in range(100): lb = torch.rand(10) - 1 ub = torch.rand(10) + 1 - ts = BoundedTensorSpec(lb, ub, dtype=dtype) + ts = Bounded(lb, ub, dtype=dtype) _dtype = dtype if dtype is None: _dtype = torch.get_default_dtype() @@ -150,7 +162,7 @@ def test_ndunbounded(dtype, n, shape): torch.manual_seed(0) np.random.seed(0) - ts = UnboundedContinuousTensorSpec( + ts = Unbounded( shape=[ n, ], @@ -195,7 +207,7 @@ def test_binary(n, shape): torch.manual_seed(0) np.random.seed(0) - ts = BinaryDiscreteTensorSpec(n) + ts = Binary(n) for _ in range(100): r = ts.rand(shape) assert r.shape == torch.Size( @@ -238,7 +250,7 @@ def test_binary(n, shape): def test_mult_onehot(shape, ns): torch.manual_seed(0) np.random.seed(0) - ts = MultiOneHotDiscreteTensorSpec(nvec=ns) + ts = MultiOneHot(nvec=ns) for _ in range(100): r = ts.rand(shape) assert r.shape == torch.Size( @@ -279,7 +291,7 @@ def test_mult_onehot(shape, ns): def test_multi_discrete(shape, ns, dtype): torch.manual_seed(0) np.random.seed(0) - ts = MultiDiscreteTensorSpec(ns, dtype=dtype) + ts = MultiCategorical(ns, dtype=dtype) _real_shape = shape if shape is not None else [] nvec_shape = torch.tensor(ns).size() for _ in range(100): @@ -315,9 +327,9 @@ def test_multi_discrete(shape, ns, dtype): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) def test_discrete_conversion(n, device, shape): - categorical = DiscreteTensorSpec(n, device=device, shape=shape) + categorical = Categorical(n, device=device, shape=shape) shape_one_hot = [n] if not shape else [*shape, n] - one_hot = OneHotDiscreteTensorSpec(n, device=device, shape=shape_one_hot) + one_hot = OneHot(n, device=device, shape=shape_one_hot) assert categorical != one_hot assert categorical.to_one_hot_spec() == one_hot @@ -333,8 +345,8 @@ def test_discrete_conversion(n, device, shape): @pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) @pytest.mark.parametrize("device", get_default_devices()) def test_multi_discrete_conversion(ns, shape, device): - categorical = MultiDiscreteTensorSpec(ns, device=device) - one_hot = MultiOneHotDiscreteTensorSpec(ns, device=device) + categorical = MultiCategorical(ns, device=device) + one_hot = MultiOneHot(ns, device=device) assert categorical != one_hot assert categorical.to_one_hot_spec() == one_hot @@ -356,14 +368,14 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): torch.manual_seed(0) np.random.seed(0) - return CompositeSpec( - obs=BoundedTensorSpec( + return Composite( + obs=Bounded( torch.zeros(*shape, 3, 32, 32), torch.ones(*shape, 3, 32, 32), dtype=dtype, device=device, ), - act=UnboundedContinuousTensorSpec( + act=Unbounded( ( *shape, 7, @@ -379,9 +391,9 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): def test_getitem(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) - assert isinstance(ts["obs"], BoundedTensorSpec) + assert isinstance(ts["obs"], Bounded) if is_complete: - assert isinstance(ts["act"], UnboundedContinuousTensorSpec) + assert isinstance(ts["act"], Unbounded) else: assert ts["act"] is None with pytest.raises(KeyError): @@ -397,21 +409,17 @@ def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest): ts = self._composite_spec(shape, is_complete, device, dtype) - ts["good"] = UnboundedContinuousTensorSpec( - shape=shape, device=device, dtype=dtype - ) + ts["good"] = Unbounded(shape=shape, device=device, dtype=dtype) cm = ( contextlib.nullcontext() if (device == dest) or (device is None) else pytest.raises( - RuntimeError, match="All devices of CompositeSpec must match" + RuntimeError, match="All devices of Composite must match" ) ) with cm: # auto-casting is introduced since v0.3 - ts["bad"] = UnboundedContinuousTensorSpec( - shape=shape, device=dest, dtype=dtype - ) + ts["bad"] = Unbounded(shape=shape, device=dest, dtype=dtype) assert ts.device == device assert ts["good"].device == ( device if device is not None else torch.zeros(()).device @@ -490,7 +498,7 @@ def test_rand(self, shape, is_complete, device, dtype, shape_other): def test_repr(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) output = repr(ts) - assert output.startswith("CompositeSpec") + assert output.startswith("Composite") assert "obs: " in output assert "act: " in output @@ -606,7 +614,7 @@ def test_nested_composite_spec_delitem(self, shape, is_complete, device, dtype): def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(new=None) + td2 = Composite(new=None) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -619,7 +627,7 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device)) + td2 = Composite(nested_cp=Composite(new=None).to(device)) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -632,7 +640,7 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) + td2 = Composite(nested_cp=Composite(act=None).to(device)) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -645,13 +653,13 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec( - nested_cp=CompositeSpec(act=None, shape=shape).to(device), shape=shape + td2 = Composite( + nested_cp=Composite(act=None, shape=shape).to(device), shape=shape ) ts.update(td2) - td2 = CompositeSpec( - nested_cp=CompositeSpec( - act=UnboundedContinuousTensorSpec(shape=shape, device=device), + td2 = Composite( + nested_cp=Composite( + act=Unbounded(shape=shape, device=device), shape=shape, ), shape=shape, @@ -668,8 +676,8 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): def test_change_batch_size(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) - ts["nested"] = CompositeSpec( - leaf=UnboundedContinuousTensorSpec(shape, device=device), + ts["nested"] = Composite( + leaf=Unbounded(shape, device=device), shape=shape, device=device, ) @@ -690,12 +698,12 @@ def test_change_batch_size(self, shape, is_complete, device, dtype): @pytest.mark.parametrize("device", get_default_devices()) def test_create_composite_nested(shape, device): d = [ - {("a", "b"): UnboundedContinuousTensorSpec(shape=shape, device=device)}, - {"a": {"b": UnboundedContinuousTensorSpec(shape=shape, device=device)}}, + {("a", "b"): Unbounded(shape=shape, device=device)}, + {"a": {"b": Unbounded(shape=shape, device=device)}}, ] for _d in d: - c = CompositeSpec(_d, shape=shape) - assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec) + c = Composite(_d, shape=shape) + assert isinstance(c["a", "b"], Unbounded) assert c["a"].shape == torch.Size(shape) assert c.device is None # device not explicitly passed assert c["a"].device is None # device not explicitly passed @@ -708,10 +716,8 @@ def test_create_composite_nested(shape, device): @pytest.mark.parametrize("recurse", [True, False]) def test_lock(recurse): shape = [3, 4, 5] - spec = CompositeSpec( - a=CompositeSpec( - b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2] - ), + spec = Composite( + a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), shape=shape[:1], ) spec["a"] = spec["a"].clone() @@ -719,15 +725,15 @@ def test_lock(recurse): assert not spec.locked spec.lock_(recurse=recurse) assert spec.locked - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a"] = spec["a"].clone() - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec.set("a", spec["a"].clone()) if recurse: assert spec["a"].locked - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a"].set("b", spec["a", "b"].clone()) - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a", "b"] = spec["a", "b"].clone() else: assert not spec["a"].locked @@ -763,33 +769,25 @@ def test_equality_bounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) + ts = Bounded(minimum, maximum, torch.Size((1,)), device, dtype) - ts_same = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) + ts_same = Bounded(minimum, maximum, torch.Size((1,)), device, dtype) assert ts == ts_same - ts_other = BoundedTensorSpec( - minimum + 1, maximum, torch.Size((1,)), device, dtype - ) + ts_other = Bounded(minimum + 1, maximum, torch.Size((1,)), device, dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - minimum, maximum + 1, torch.Size((1,)), device, dtype - ) + ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BoundedTensorSpec( - minimum, maximum, torch.Size((1,)), "cuda:0", dtype - ) + ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - minimum, maximum, torch.Size((1,)), device, torch.float64 - ) + ts_other = Bounded(minimum, maximum, torch.Size((1,)), device, torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -799,38 +797,34 @@ def test_equality_onehot(self): dtype = torch.float16 use_register = False - ts = OneHotDiscreteTensorSpec( - n=n, device=device, dtype=dtype, use_register=use_register - ) + ts = OneHot(n=n, device=device, dtype=dtype, use_register=use_register) - ts_same = OneHotDiscreteTensorSpec( - n=n, device=device, dtype=dtype, use_register=use_register - ) + ts_same = OneHot(n=n, device=device, dtype=dtype, use_register=use_register) assert ts == ts_same - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n + 1, device=device, dtype=dtype, use_register=use_register ) assert ts != ts_other if torch.cuda.device_count(): - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device="cuda:0", dtype=dtype, use_register=use_register ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device=device, dtype=torch.float64, use_register=use_register ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device=device, dtype=dtype, use_register=not use_register ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -838,21 +832,25 @@ def test_equality_unbounded(self): device = "cpu" dtype = torch.float16 - ts = UnboundedContinuousTensorSpec(device=device, dtype=dtype) + ts = Unbounded(device=device, dtype=dtype) - ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype) + ts_same = Unbounded(device=device, dtype=dtype) assert ts == ts_same if torch.cuda.device_count(): - ts_other = UnboundedContinuousTensorSpec(device="cuda:0", dtype=dtype) + ts_other = Unbounded(device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = UnboundedContinuousTensorSpec(device=device, dtype=torch.float64) + ts_other = Unbounded(device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts + ) + ts_other.space = ContinuousBox( + ts_other.space.low * 0, ts_other.space.high * 0 + 1 ) + assert ts.space != ts_other.space, (ts.space, ts_other.space) assert ts != ts_other def test_equality_ndbounded(self): @@ -861,36 +859,28 @@ def test_equality_ndbounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec(low=minimum, high=maximum, device=device, dtype=dtype) + ts = Bounded(low=minimum, high=maximum, device=device, dtype=dtype) - ts_same = BoundedTensorSpec( - low=minimum, high=maximum, device=device, dtype=dtype - ) + ts_same = Bounded(low=minimum, high=maximum, device=device, dtype=dtype) assert ts == ts_same - ts_other = BoundedTensorSpec( - low=minimum + 1, high=maximum, device=device, dtype=dtype - ) + ts_other = Bounded(low=minimum + 1, high=maximum, device=device, dtype=dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) + ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BoundedTensorSpec( - low=minimum, high=maximum, device="cuda:0", dtype=dtype - ) + ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( + ts_other = Bounded( low=minimum, high=maximum, device=device, dtype=torch.float64 ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -900,32 +890,28 @@ def test_equality_discrete(self): device = "cpu" dtype = torch.float16 - ts = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) + ts = Categorical(n=n, shape=shape, device=device, dtype=dtype) - ts_same = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) + ts_same = Categorical(n=n, shape=shape, device=device, dtype=dtype) assert ts == ts_same - ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype) + ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = DiscreteTensorSpec( - n=n, shape=shape, device="cuda:0", dtype=dtype - ) + ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = DiscreteTensorSpec( - n=n, shape=shape, device=device, dtype=torch.float64 - ) + ts_other = Categorical(n=n, shape=shape, device=device, dtype=torch.float64) assert ts != ts_other - ts_other = DiscreteTensorSpec( + ts_other = Categorical( n=n, shape=torch.Size([2]), device=device, dtype=torch.float64 ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -941,30 +927,24 @@ def test_equality_ndunbounded(self, shape): device = "cpu" dtype = torch.float16 - ts = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) + ts = Unbounded(shape=shape, device=device, dtype=dtype) - ts_same = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) + ts_same = Unbounded(shape=shape, device=device, dtype=dtype) assert ts == ts_same - other_shape = 13 if type(shape) == int else torch.Size(np.array(shape) + 10) - ts_other = UnboundedContinuousTensorSpec( - shape=other_shape, device=device, dtype=dtype - ) + other_shape = 13 if isinstance(shape, int) else torch.Size(np.array(shape) + 10) + ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = UnboundedContinuousTensorSpec( - shape=shape, device="cuda:0", dtype=dtype - ) + ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = UnboundedContinuousTensorSpec( - shape=shape, device=device, dtype=torch.float64 - ) + ts_other = Unbounded(shape=shape, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) # Unbounded and bounded without space are technically the same assert ts == ts_other @@ -974,23 +954,23 @@ def test_equality_binary(self): device = "cpu" dtype = torch.float16 - ts = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype) + ts = Binary(n=n, device=device, dtype=dtype) - ts_same = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype) + ts_same = Binary(n=n, device=device, dtype=dtype) assert ts == ts_same - ts_other = BinaryDiscreteTensorSpec(n=n + 5, device=device, dtype=dtype) + ts_other = Binary(n=n + 5, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BinaryDiscreteTensorSpec(n=n, device="cuda:0", dtype=dtype) + ts_other = Binary(n=n, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = BinaryDiscreteTensorSpec(n=n, device=device, dtype=torch.float64) + ts_other = Binary(n=n, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -999,42 +979,32 @@ def test_equality_multi_onehot(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiOneHot(nvec=nvec, device=device, dtype=dtype) - ts_same = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiOneHot(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12] - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12, 13] - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=nvec, device="cuda:0", dtype=dtype - ) + ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=nvec, device=device, dtype=torch.float64 - ) + ts_other = MultiOneHot(nvec=nvec, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -1043,34 +1013,32 @@ def test_equality_multi_discrete(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiCategorical(nvec=nvec, device=device, dtype=dtype) - ts_same = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiCategorical(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12] - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12, 13] - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cuda:0", dtype=dtype) + ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = MultiDiscreteTensorSpec( - nvec=nvec, device=device, dtype=torch.float64 - ) + ts_other = MultiCategorical(nvec=nvec, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -1080,69 +1048,63 @@ def test_equality_composite(self): device = "cpu" dtype = torch.float16 - bounded = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype) - bounded_same = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype) - bounded_other = BoundedTensorSpec(0, 2, torch.Size((1,)), device, dtype) + bounded = Bounded(0, 1, torch.Size((1,)), device, dtype) + bounded_same = Bounded(0, 1, torch.Size((1,)), device, dtype) + bounded_other = Bounded(0, 2, torch.Size((1,)), device, dtype) - nd = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) - nd_same = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) - _ = BoundedTensorSpec(low=minimum, high=maximum + 3, device=device, dtype=dtype) + nd = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) + nd_same = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) + _ = Bounded(low=minimum, high=maximum + 3, device=device, dtype=dtype) # Equality tests - ts = CompositeSpec(ts1=bounded) - ts_same = CompositeSpec(ts1=bounded) + ts = Composite(ts1=bounded) + ts_same = Composite(ts1=bounded) assert ts == ts_same - ts = CompositeSpec(ts1=bounded) - ts_same = CompositeSpec(ts1=bounded_same) + ts = Composite(ts1=bounded) + ts_same = Composite(ts1=bounded_same) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = CompositeSpec(ts1=bounded, ts2=nd) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts1=bounded, ts2=nd) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = CompositeSpec(ts1=bounded_same, ts2=nd_same) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts1=bounded_same, ts2=nd_same) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = CompositeSpec(ts2=nd_same, ts1=bounded_same) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts2=nd_same, ts1=bounded_same) assert ts == ts_same # Inequality tests - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts5=bounded) + ts = Composite(ts1=bounded) + ts_other = Composite(ts5=bounded) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=bounded_other) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=bounded_other) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=nd) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=bounded, ts2=nd) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=bounded, ts2=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_other = CompositeSpec(ts2=nd) + ts = Composite(ts1=bounded, ts2=nd) + ts_other = Composite(ts2=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_other = CompositeSpec(ts1=bounded, ts2=nd, ts3=bounded_other) + ts = Composite(ts1=bounded, ts2=nd) + ts_other = Composite(ts1=bounded, ts2=nd, ts3=bounded_other) assert ts != ts_other class TestSpec: - @pytest.mark.parametrize( - "action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec] - ) + @pytest.mark.parametrize("action_spec_cls", [OneHot, Categorical]) def test_discrete_action_spec_reconstruct(self, action_spec_cls): torch.manual_seed(0) action_spec = action_spec_cls(10) @@ -1161,7 +1123,7 @@ def test_discrete_action_spec_reconstruct(self, action_spec_cls): def test_mult_discrete_action_spec_reconstruct(self): torch.manual_seed(0) - action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHot((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors] @@ -1183,7 +1145,7 @@ def test_mult_discrete_action_spec_reconstruct(self): def test_one_hot_discrete_action_spec_rand(self): torch.manual_seed(0) - action_spec = OneHotDiscreteTensorSpec(10) + action_spec = OneHot(10) sample = action_spec.rand((100000,)) @@ -1197,7 +1159,7 @@ def test_one_hot_discrete_action_spec_rand(self): def test_categorical_action_spec_rand(self): torch.manual_seed(1) - action_spec = DiscreteTensorSpec(10) + action_spec = Categorical(10) sample = action_spec.rand((10000,)) @@ -1213,7 +1175,7 @@ def test_mult_discrete_action_spec_rand(self): torch.manual_seed(0) ns = (10, 5) N = 100000 - action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHot((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors] @@ -1238,7 +1200,7 @@ def test_mult_discrete_action_spec_rand(self): assert chisquare(sample_list).pvalue > 0.1 def test_categorical_action_spec_encode(self): - action_spec = DiscreteTensorSpec(10) + action_spec = Categorical(10) projected = action_spec.project( torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long) @@ -1255,12 +1217,12 @@ def test_categorical_action_spec_encode(self): ).all() def test_bounded_rand(self): - spec = BoundedTensorSpec(-3, 3, torch.Size((1,))) + spec = Bounded(-3, 3, torch.Size((1,))) sample = torch.stack([spec.rand() for _ in range(100)]) assert (-3 <= sample).all() and (3 >= sample).all() def test_ndbounded_shape(self): - spec = BoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) + spec = Bounded(-3, 3 * torch.ones(10, 5), shape=[10, 5]) sample = torch.stack([spec.rand() for _ in range(100)], 0) assert (-3 <= sample).all() and (3 >= sample).all() assert sample.shape == torch.Size([100, 10, 5]) @@ -1270,9 +1232,7 @@ class TestExpand: @pytest.mark.parametrize("shape1", [None, (4,), (5, 4)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_binary(self, shape1, shape2): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1304,9 +1264,7 @@ def test_binary(self, shape1, shape2): ], ) def test_bounded(self, shape1, shape2, mini, maxi): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) shape1 = spec.shape assert shape1 == torch.Size([10]) shape2_real = (*shape2, *shape1) @@ -1326,7 +1284,7 @@ def test_bounded(self, shape1, shape2, mini, maxi): def test_composite(self): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1336,22 +1294,16 @@ def test_composite(self): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1361,7 +1313,7 @@ def test_composite(self): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -1392,7 +1344,7 @@ def test_composite(self): @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_discrete(self, shape1, shape2): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1418,7 +1370,7 @@ def test_multidiscrete(self, shape1, shape2): shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long ) if shape1 is not None: @@ -1446,9 +1398,7 @@ def test_multionehot(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = MultiOneHotDiscreteTensorSpec( - nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long - ) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1468,11 +1418,11 @@ def test_multionehot(self, shape1, shape2): assert spec2.zero().shape == spec2.shape def test_non_tensor(self): - spec = NonTensorSpec((3, 4), device="cpu") + spec = NonTensor((3, 4), device="cpu") assert ( spec.expand(2, 3, 4) == spec.expand((2, 3, 4)) - == NonTensorSpec((2, 3, 4), device="cpu") + == NonTensor((2, 3, 4), device="cpu") ) @pytest.mark.parametrize("shape1", [None, (), (5,)]) @@ -1482,9 +1432,7 @@ def test_onehot(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = OneHotDiscreteTensorSpec( - n=15, shape=shape1, device="cpu", dtype=torch.long - ) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1510,9 +1458,7 @@ def test_unbounded(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1571,9 +1517,7 @@ class TestClone: ], ) def test_binary(self, shape1): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) assert spec == spec.clone() assert spec is not spec.clone() @@ -1589,15 +1533,13 @@ def test_binary(self, shape1): ], ) def test_bounded(self, shape1, mini, maxi): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) assert spec == spec.clone() assert spec is not spec.clone() def test_composite(self): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1607,22 +1549,16 @@ def test_composite(self): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1632,7 +1568,7 @@ def test_composite(self): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -1654,7 +1590,7 @@ def test_discrete( self, shape1, ): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) assert spec == spec.clone() assert spec is not spec.clone() @@ -1667,7 +1603,7 @@ def test_multidiscrete( shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long ) assert spec == spec.clone() @@ -1682,14 +1618,12 @@ def test_multionehot( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = MultiOneHotDiscreteTensorSpec( - nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long - ) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) assert spec == spec.clone() assert spec is not spec.clone() def test_non_tensor(self): - spec = NonTensorSpec(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu") assert spec.clone() == spec assert spec.clone() is not spec @@ -1702,9 +1636,7 @@ def test_onehot( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = OneHotDiscreteTensorSpec( - n=15, shape=shape1, device="cpu", dtype=torch.long - ) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) assert spec == spec.clone() assert spec is not spec.clone() @@ -1717,9 +1649,7 @@ def test_unbounded( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) assert spec == spec.clone() assert spec is not spec.clone() @@ -1740,9 +1670,7 @@ def test_unboundeddiscrete( class TestUnbind: @pytest.mark.parametrize("shape1", [(5, 4)]) def test_binary(self, shape1): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) assert spec == torch.stack(spec.unbind(0), 0) with pytest.raises(ValueError): spec.unbind(-1) @@ -1759,16 +1687,14 @@ def test_binary(self, shape1): ], ) def test_bounded(self, shape1, mini, maxi): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) assert spec == torch.stack(spec.unbind(0), 0) with pytest.raises(ValueError): spec.unbind(-1) def test_composite(self): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1778,22 +1704,16 @@ def test_composite(self): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1803,7 +1723,7 @@ def test_composite(self): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -1822,7 +1742,7 @@ def test_discrete( self, shape1, ): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) assert spec == torch.stack(spec.unbind(0), 0) assert spec == torch.stack(spec.unbind(-1), -1) @@ -1835,7 +1755,7 @@ def test_multidiscrete( shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long ) assert spec == torch.stack(spec.unbind(0), 0) @@ -1851,15 +1771,13 @@ def test_multionehot( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = MultiOneHotDiscreteTensorSpec( - nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long - ) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) assert spec == torch.stack(spec.unbind(0), 0) with pytest.raises(ValueError): spec.unbind(-1) def test_non_tensor(self): - spec = NonTensorSpec(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu") assert spec.unbind(1)[0] == spec[:, 0] assert spec.unbind(1)[0] is not spec[:, 0] @@ -1872,9 +1790,7 @@ def test_onehot( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = OneHotDiscreteTensorSpec( - n=15, shape=shape1, device="cpu", dtype=torch.long - ) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) assert spec == torch.stack(spec.unbind(0), 0) with pytest.raises(ValueError): spec.unbind(-1) @@ -1888,9 +1804,7 @@ def test_unbounded( shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) assert spec == torch.stack(spec.unbind(0), 0) assert spec == torch.stack(spec.unbind(-1), -1) @@ -1908,15 +1822,15 @@ def test_unboundeddiscrete( assert spec == torch.stack(spec.unbind(-1), -1) def test_composite_encode_err(self): - c = CompositeSpec( - a=UnboundedContinuousTensorSpec( + c = Composite( + a=Unbounded( 1, ), - b=UnboundedContinuousTensorSpec( + b=Unbounded( 2, ), ) - with pytest.raises(KeyError, match="The CompositeSpec instance with keys"): + with pytest.raises(KeyError, match="The Composite instance with keys"): c.encode({"c": 0}) with pytest.raises( RuntimeError, match="raised a RuntimeError. Scroll up to know more" @@ -1932,9 +1846,7 @@ def test_composite_encode_err(self): class TestTo: @pytest.mark.parametrize("shape1", [(5, 4)]) def test_binary(self, shape1, device): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) assert spec.to(device).device == device @pytest.mark.parametrize( @@ -1949,14 +1861,12 @@ def test_binary(self, shape1, device): ], ) def test_bounded(self, shape1, mini, maxi, device): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) assert spec.to(device).device == device def test_composite(self, device): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1966,22 +1876,16 @@ def test_composite(self, device): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1991,7 +1895,7 @@ def test_composite(self, device): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -2010,7 +1914,7 @@ def test_discrete( shape1, device, ): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) assert spec.to(device).device == device @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2019,7 +1923,7 @@ def test_multidiscrete(self, shape1, device): shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long ) assert spec.to(device).device == device @@ -2030,13 +1934,11 @@ def test_multionehot(self, shape1, device): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = MultiOneHotDiscreteTensorSpec( - nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long - ) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) assert spec.to(device).device == device def test_non_tensor(self, device): - spec = NonTensorSpec(shape=(3, 4), device="cpu") + spec = NonTensor(shape=(3, 4), device="cpu") assert spec.to(device).device == device @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2045,9 +1947,7 @@ def test_onehot(self, shape1, device): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = OneHotDiscreteTensorSpec( - n=15, shape=shape1, device="cpu", dtype=torch.long - ) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) assert spec.to(device).device == device @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2056,9 +1956,7 @@ def test_unbounded(self, shape1, device): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) assert spec.to(device).device == device @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2079,10 +1977,10 @@ class TestStack: def test_stack_binarydiscrete(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, BinaryDiscreteTensorSpec) + assert isinstance(c, Binary) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2092,7 +1990,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim): def test_stack_binarydiscrete_expand(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2105,7 +2003,7 @@ def test_stack_binarydiscrete_expand(self, shape, stack_dim): def test_stack_binarydiscrete_rand(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2114,7 +2012,7 @@ def test_stack_binarydiscrete_rand(self, shape, stack_dim): def test_stack_binarydiscrete_zero(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2124,10 +2022,10 @@ def test_stack_bounded(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, BoundedTensorSpec) + assert isinstance(c, Bounded) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2138,7 +2036,7 @@ def test_stack_bounded_expand(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2152,7 +2050,7 @@ def test_stack_bounded_rand(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2162,7 +2060,7 @@ def test_stack_bounded_zero(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2171,10 +2069,10 @@ def test_stack_bounded_zero(self, shape, stack_dim): def test_stack_discrete(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, DiscreteTensorSpec) + assert isinstance(c, Categorical) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2184,7 +2082,7 @@ def test_stack_discrete(self, shape, stack_dim): def test_stack_discrete_expand(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2197,7 +2095,7 @@ def test_stack_discrete_expand(self, shape, stack_dim): def test_stack_discrete_rand(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2206,7 +2104,7 @@ def test_stack_discrete_rand(self, shape, stack_dim): def test_stack_discrete_zero(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2215,10 +2113,10 @@ def test_stack_discrete_zero(self, shape, stack_dim): def test_stack_multidiscrete(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, MultiDiscreteTensorSpec) + assert isinstance(c, MultiCategorical) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2228,7 +2126,7 @@ def test_stack_multidiscrete(self, shape, stack_dim): def test_stack_multidiscrete_expand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2241,7 +2139,7 @@ def test_stack_multidiscrete_expand(self, shape, stack_dim): def test_stack_multidiscrete_rand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2250,7 +2148,7 @@ def test_stack_multidiscrete_rand(self, shape, stack_dim): def test_stack_multidiscrete_zero(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2259,10 +2157,10 @@ def test_stack_multidiscrete_zero(self, shape, stack_dim): def test_stack_multionehot(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, MultiOneHotDiscreteTensorSpec) + assert isinstance(c, MultiOneHot) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2272,7 +2170,7 @@ def test_stack_multionehot(self, shape, stack_dim): def test_stack_multionehot_expand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2285,7 +2183,7 @@ def test_stack_multionehot_expand(self, shape, stack_dim): def test_stack_multionehot_rand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2294,15 +2192,15 @@ def test_stack_multionehot_rand(self, shape, stack_dim): def test_stack_multionehot_zero(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() assert r.shape == c.shape def test_stack_non_tensor(self, shape, stack_dim): - spec0 = NonTensorSpec(shape=shape, device="cpu") - spec1 = NonTensorSpec(shape=shape, device="cpu") + spec0 = NonTensor(shape=shape, device="cpu") + spec1 = NonTensor(shape=shape, device="cpu") new_spec = torch.stack([spec0, spec1], stack_dim) shape_insert = list(shape) shape_insert.insert(stack_dim, 2) @@ -2312,10 +2210,10 @@ def test_stack_non_tensor(self, shape, stack_dim): def test_stack_onehot(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, OneHotDiscreteTensorSpec) + assert isinstance(c, OneHot) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2325,7 +2223,7 @@ def test_stack_onehot(self, shape, stack_dim): def test_stack_onehot_expand(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2338,7 +2236,7 @@ def test_stack_onehot_expand(self, shape, stack_dim): def test_stack_onehot_rand(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2347,7 +2245,7 @@ def test_stack_onehot_rand(self, shape, stack_dim): def test_stack_onehot_zero(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2355,10 +2253,10 @@ def test_stack_onehot_zero(self, shape, stack_dim): def test_stack_unboundedcont(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, UnboundedContinuousTensorSpec) + assert isinstance(c, Unbounded) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2367,7 +2265,7 @@ def test_stack_unboundedcont(self, shape, stack_dim): def test_stack_unboundedcont_expand(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2379,7 +2277,7 @@ def test_stack_unboundedcont_expand(self, shape, stack_dim): def test_stack_unboundedcont_rand(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2434,8 +2332,8 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): assert r.shape == c.shape def test_to_numpy(self, shape, stack_dim): - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) - c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64) + c2 = Bounded(-1, 1, shape=shape, dtype=torch.float64) c = torch.stack([c1, c2], stack_dim) @@ -2455,13 +2353,13 @@ def test_to_numpy(self, shape, stack_dim): c.to_numpy(val + 1, safe=True) def test_malformed_stack(self, shape, stack_dim): - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) - c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64) + c2 = Bounded(-1, 1, shape=shape, dtype=torch.float32) with pytest.raises(RuntimeError, match="Dtypes differ"): torch.stack([c1, c2], stack_dim) - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) - c2 = UnboundedContinuousTensorSpec(shape=shape, dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32) + c2 = Unbounded(shape=shape, dtype=torch.float32) c3 = UnboundedDiscreteTensorSpec(shape=shape, dtype=torch.float32) with pytest.raises( RuntimeError, @@ -2470,40 +2368,40 @@ def test_malformed_stack(self, shape, stack_dim): torch.stack([c1, c2], stack_dim) torch.stack([c3, c2], stack_dim) - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) - c2 = BoundedTensorSpec(-1, 1, shape=shape + (3,), dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32) + c2 = Bounded(-1, 1, shape=shape + (3,), dtype=torch.float32) with pytest.raises(RuntimeError, match="Ndims differ"): torch.stack([c1, c2], stack_dim) -class TestDenseStackedCompositeSpecs: +class TestDenseStackedComposite: def test_stack(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) + c1 = Composite(a=Unbounded()) c2 = c1.clone() c = torch.stack([c1, c2], 0) - assert isinstance(c, CompositeSpec) + assert isinstance(c, Composite) -class TestLazyStackedCompositeSpecs: +class TestLazyStackedComposite: def _get_heterogeneous_specs( self, batch_size=(), stack_dim: int = 0, ): - shared = BoundedTensorSpec(low=0, high=1, shape=(*batch_size, 32, 32, 3)) - hetero_3d = UnboundedContinuousTensorSpec( + shared = Bounded(low=0, high=1, shape=(*batch_size, 32, 32, 3)) + hetero_3d = Unbounded( shape=( *batch_size, 3, ) ) - hetero_2d = UnboundedContinuousTensorSpec( + hetero_2d = Unbounded( shape=( *batch_size, 2, ) ) - lidar = BoundedTensorSpec( + lidar = Bounded( low=0, high=5, shape=( @@ -2512,9 +2410,9 @@ def _get_heterogeneous_specs( ), ) - individual_0_obs = CompositeSpec( + individual_0_obs = Composite( { - "individual_0_obs_0": UnboundedContinuousTensorSpec( + "individual_0_obs_0": Unbounded( shape=( *batch_size, 3, @@ -2524,25 +2422,21 @@ def _get_heterogeneous_specs( }, shape=(*batch_size, 3), ) - individual_1_obs = CompositeSpec( + individual_1_obs = Composite( { - "individual_1_obs_0": BoundedTensorSpec( + "individual_1_obs_0": Bounded( low=0, high=3, shape=(*batch_size, 3, 1, 2) ) }, shape=(*batch_size, 3), ) - individual_2_obs = CompositeSpec( - { - "individual_1_obs_0": UnboundedContinuousTensorSpec( - shape=(*batch_size, 3, 1, 2, 3) - ) - }, + individual_2_obs = Composite( + {"individual_1_obs_0": Unbounded(shape=(*batch_size, 3, 1, 2, 3))}, shape=(*batch_size, 3), ) spec_list = [ - CompositeSpec( + Composite( { "shared": shared, "lidar": lidar, @@ -2551,7 +2445,7 @@ def _get_heterogeneous_specs( }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, "lidar": lidar, @@ -2560,7 +2454,7 @@ def _get_heterogeneous_specs( }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, "hetero": hetero_2d, @@ -2573,10 +2467,8 @@ def _get_heterogeneous_specs( return torch.stack(spec_list, dim=stack_dim).cpu() def test_stack_index(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec() - ) + c1 = Composite(a=Unbounded()) + c2 = Composite(a=Unbounded(), b=UnboundedDiscreteTensorSpec()) c = torch.stack([c1, c2], 0) assert c.shape == torch.Size([2]) assert c[0] is c1 @@ -2585,19 +2477,19 @@ def test_stack_index(self): assert c[..., 1] is c2 assert c[0, ...] is c1 assert c[1, ...] is c2 - assert isinstance(c[:], LazyStackedCompositeSpec) + assert isinstance(c[:], StackedComposite) @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_index_multdim(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) c = torch.stack([c1, c2], stack_dim) if stack_dim in (0, -3): - assert isinstance(c[:], LazyStackedCompositeSpec) + assert isinstance(c[:], StackedComposite) assert c.shape == torch.Size([2, 1, 3]) assert c[0] is c1 assert c[1] is c2 @@ -2614,7 +2506,7 @@ def test_stack_index_multdim(self, stack_dim): assert c[0, ...] is c1 assert c[1, ...] is c2 elif stack_dim == (1, -2): - assert isinstance(c[:, :], LazyStackedCompositeSpec) + assert isinstance(c[:, :], StackedComposite) assert c.shape == torch.Size([1, 2, 3]) assert c[:, 0] is c1 assert c[:, 1] is c2 @@ -2641,7 +2533,7 @@ def test_stack_index_multdim(self, stack_dim): assert c[:, 0, ...] is c1 assert c[:, 1, ...] is c2 elif stack_dim == (2, -1): - assert isinstance(c[:, :, :], LazyStackedCompositeSpec) + assert isinstance(c[:, :, :], StackedComposite) with pytest.raises( IndexError, match="along dimension 0 when the stack dimension is 2." ): @@ -2660,9 +2552,9 @@ def test_stack_index_multdim(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_expand_multi(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2691,9 +2583,9 @@ def test_stack_expand_multi(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2713,9 +2605,9 @@ def test_stack_rand(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand_shape(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2736,9 +2628,9 @@ def test_stack_rand_shape(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2758,9 +2650,9 @@ def test_stack_zero(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero_shape(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2782,14 +2674,14 @@ def test_stack_zero_shape(self, stack_dim): @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_to(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedCompositeSpec) + assert isinstance(c, StackedComposite) cdevice = c.to("cuda:0") assert cdevice.device != c.device assert cdevice.device == torch.device("cuda:0") @@ -2799,9 +2691,9 @@ def test_to(self, stack_dim): assert cdevice[index].device == torch.device("cuda:0") def test_clone(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2811,9 +2703,9 @@ def test_clone(self): assert cclone[0] == c[0] def test_to_numpy(self): - c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Bounded(-1, 1, shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2829,9 +2721,9 @@ def test_to_numpy(self): c.to_numpy(td_fail, safe=True) def test_unsqueeze(self): - c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Bounded(-1, 1, shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2984,12 +2876,11 @@ def test_project(self, batch_size): def test_repr(self): c = self._get_heterogeneous_specs() - - expected = f"""LazyStackedCompositeSpec( + expected = f"""StackedComposite( fields={{ - hetero: LazyStackedUnboundedContinuousTensorSpec( + hetero: StackedUnboundedContinuous( shape=torch.Size([3, -1]), device=cpu, dtype=torch.float32, domain=continuous), - shared: BoundedTensorSpec( + shared: BoundedContinuous( shape=torch.Size([3, 32, 32, 3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -2999,7 +2890,7 @@ def test_repr(self): domain=continuous)}}, exclusive_fields={{ 0 -> - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([20]), space=ContinuousBox( low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3007,17 +2898,19 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - individual_0_obs: CompositeSpec( - individual_0_obs_0: UnboundedContinuousTensorSpec( + individual_0_obs: Composite( + individual_0_obs_0: UnboundedContinuous( shape=torch.Size([3, 1]), - space=None, + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([3])), 1 -> - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([20]), space=ContinuousBox( low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3025,8 +2918,8 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - individual_1_obs: CompositeSpec( - individual_1_obs_0: BoundedTensorSpec( + individual_1_obs: Composite( + individual_1_obs_0: BoundedContinuous( shape=torch.Size([3, 1, 2]), space=ContinuousBox( low=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3037,10 +2930,12 @@ def test_repr(self): device=cpu, shape=torch.Size([3])), 2 -> - individual_2_obs: CompositeSpec( - individual_1_obs_0: UnboundedContinuousTensorSpec( + individual_2_obs: Composite( + individual_1_obs_0: UnboundedContinuous( shape=torch.Size([3, 1, 2, 3]), - space=None, + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), @@ -3054,11 +2949,11 @@ def test_repr(self): c = c[0:2] del c["individual_0_obs"] del c["individual_1_obs"] - expected = f"""LazyStackedCompositeSpec( + expected = f"""StackedComposite( fields={{ - hetero: LazyStackedUnboundedContinuousTensorSpec( + hetero: StackedUnboundedContinuous( shape=torch.Size([2, -1]), device=cpu, dtype=torch.float32, domain=continuous), - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([2, 20]), space=ContinuousBox( low=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3066,7 +2961,7 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - shared: BoundedTensorSpec( + shared: BoundedContinuous( shape=torch.Size([2, 32, 32, 3]), space=ContinuousBox( low=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3100,7 +2995,7 @@ def test_consolidate_spec(self, batch_size): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): - shared = UnboundedContinuousTensorSpec( + shared = Unbounded( shape=( *batch_size, 5, @@ -3110,29 +3005,29 @@ def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): ) lazy_spec = torch.stack( [ - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 6, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 7, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), + Unbounded(shape=(*batch_size, 5, 6, 7)), + Unbounded(shape=(*batch_size, 5, 7, 7)), + Unbounded(shape=(*batch_size, 5, 8, 7)), + Unbounded(shape=(*batch_size, 5, 8, 7)), ], dim=len(batch_size), ) spec_list = [ - CompositeSpec( + Composite( { "shared": shared, "lazy_spec": lazy_spec, }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, }, shape=batch_size, ), - CompositeSpec( + Composite( {}, shape=batch_size, device="cpu", @@ -3168,9 +3063,7 @@ def test_update(self, batch_size, stack_dim=0): spec[1]["individual_1_obs"]["individual_1_obs_0"].space.low.sum() == 0 ) # Only non exclusive keys will be updated - new = torch.stack( - [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], 0 - ) + new = torch.stack([Unbounded(shape=(*batch_size, i)) for i in range(3)], 0) spec2["new"] = new spec.update(spec2) assert spec["new"] == new @@ -3181,7 +3074,7 @@ def test_set_item(self, batch_size, stack_dim): spec = self._get_heterogeneous_specs(batch_size, stack_dim) new = torch.stack( - [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], + [Unbounded(shape=(*batch_size, i)) for i in range(3)], stack_dim, ) spec["new"] = new @@ -3196,15 +3089,15 @@ def test_set_item(self, batch_size, stack_dim): spec[("other", "key")] = new assert spec[("other", "key")] == new - assert isinstance(spec["other"], LazyStackedCompositeSpec) + assert isinstance(spec["other"], StackedComposite) with pytest.raises(RuntimeError, match="key should be a Sequence"): spec[0] = new comp = torch.stack( [ - CompositeSpec( - {"a": UnboundedContinuousTensorSpec(shape=(*batch_size, i))}, + Composite( + {"a": Unbounded(shape=(*batch_size, i))}, shape=batch_size, ) for i in range(3) @@ -3220,10 +3113,10 @@ def test_set_item(self, batch_size, stack_dim): @pytest.mark.parametrize( "spec_class", [ - BinaryDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - CompositeSpec, + Binary, + OneHot, + MultiOneHot, + Composite, ], ) @pytest.mark.parametrize( @@ -3240,13 +3133,13 @@ def test_set_item(self, batch_size, stack_dim): ], # [:,1:2,1] ) def test_invalid_indexing(spec_class, idx): - if spec_class in [BinaryDiscreteTensorSpec, OneHotDiscreteTensorSpec]: + if spec_class in [Binary, OneHot]: spec = spec_class(n=4, shape=[3, 4]) - elif spec_class == MultiDiscreteTensorSpec: + elif spec_class == MultiCategorical: spec = spec_class([2, 2, 2], shape=[3]) - elif spec_class == MultiOneHotDiscreteTensorSpec: + elif spec_class == MultiOneHot: spec = spec_class([4], shape=[3, 4]) - elif spec_class == CompositeSpec: + elif spec_class == Composite: spec = spec_class(k=UnboundedDiscreteTensorSpec(shape=(3, 4)), shape=(3,)) with pytest.raises(IndexError): spec[idx] @@ -3256,13 +3149,13 @@ def test_invalid_indexing(spec_class, idx): @pytest.mark.parametrize( "spec_class", [ - BinaryDiscreteTensorSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, + Binary, + Categorical, + MultiOneHot, + OneHot, + Unbounded, UnboundedDiscreteTensorSpec, - CompositeSpec, + Composite, ], ) def test_valid_indexing(spec_class): @@ -3270,14 +3163,14 @@ def test_valid_indexing(spec_class): args = {"0d": [], "2d": [], "3d": [], "4d": [], "5d": []} kwargs = {} if spec_class in [ - BinaryDiscreteTensorSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Categorical, + OneHot, ]: args = {"0d": [0], "2d": [3], "3d": [4], "4d": [6], "5d": [7]} - elif spec_class == MultiOneHotDiscreteTensorSpec: + elif spec_class == MultiOneHot: args = {"0d": [[0]], "2d": [[3]], "3d": [[4]], "4d": [[6]], "5d": [[7]]} - elif spec_class == MultiDiscreteTensorSpec: + elif spec_class == MultiCategorical: args = { "0d": [[0]], "2d": [[2] * 3], @@ -3285,7 +3178,7 @@ def test_valid_indexing(spec_class): "4d": [[1] * 6], "5d": [[2] * 7], } - elif spec_class == BoundedTensorSpec: + elif spec_class == Bounded: min_max = (-1, -1) args = { "0d": min_max, @@ -3294,17 +3187,17 @@ def test_valid_indexing(spec_class): "4d": min_max, "5d": min_max, } - elif spec_class == CompositeSpec: + elif spec_class == Composite: kwargs = { "k1": UnboundedDiscreteTensorSpec(shape=(5, 3, 4, 6, 7, 8)), - "k2": OneHotDiscreteTensorSpec(n=7, shape=(5, 3, 4, 6, 7)), + "k2": OneHot(n=7, shape=(5, 3, 4, 6, 7)), } spec_0d = spec_class(*args["0d"], **kwargs) if spec_class in [ - UnboundedContinuousTensorSpec, + Unbounded, UnboundedDiscreteTensorSpec, - CompositeSpec, + Composite, ]: spec_0d = spec_class(*args["0d"], shape=[], **kwargs) spec_2d = spec_class(*args["2d"], shape=[5, 3], **kwargs) @@ -3374,10 +3267,10 @@ def test_valid_indexing(spec_class): # Specific tests when specs have non-indexable dimensions if spec_class in [ - BinaryDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, + Binary, + OneHot, + MultiCategorical, + MultiOneHot, ]: # Ellipsis assert spec_0d[None].shape == torch.Size([1, 0]) @@ -3390,7 +3283,6 @@ def test_valid_indexing(spec_class): assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 1, 4]) assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 1, 4, 6]) - # BoundedTensorSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, CompositeSpec else: # Integers assert spec_2d[0, 1].shape == torch.Size([]) @@ -3407,7 +3299,7 @@ def test_valid_indexing(spec_class): assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 4, 1, 6]) # Additional tests for composite spec - if spec_class == CompositeSpec: + if spec_class == Composite: assert spec_2d[1]["k1"].shape == torch.Size([3, 4, 6, 7, 8]) assert spec_3d[[1, 2]]["k1"].shape == torch.Size([2, 3, 4, 6, 7, 8]) assert spec_2d[torch.randint(3, (3, 2))]["k1"].shape == torch.Size( @@ -3422,9 +3314,7 @@ def test_valid_indexing(spec_class): def test_composite_contains(): - spec = CompositeSpec( - a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec())) - ) + spec = Composite(a=Composite(b=Composite(c=Unbounded()))) assert "a" in spec.keys() assert "a" in spec.keys(True) assert ("a",) in spec.keys() @@ -3444,10 +3334,10 @@ def get_all_keys(spec: TensorSpec, include_exclusive: bool): """ keys = set() - if isinstance(spec, LazyStackedCompositeSpec) and include_exclusive: + if isinstance(spec, StackedComposite) and include_exclusive: for t in spec._specs: keys = keys.union(get_all_keys(t, include_exclusive)) - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): for key in spec.keys(): keys.add((key,)) inner_keys = get_all_keys(spec[key], include_exclusive) @@ -3481,7 +3371,7 @@ def _make_mask(self, shape): def _one_hot_spec(self, shape, device, n): shape = torch.Size([*shape, n]) mask = self._make_mask(shape).to(device) - return OneHotDiscreteTensorSpec(n, shape, device, mask=mask) + return OneHot(n, shape, device, mask=mask) def _mult_one_hot_spec(self, shape, device, n): shape = torch.Size([*shape, n + n + 2]) @@ -3492,11 +3382,11 @@ def _mult_one_hot_spec(self, shape, device, n): ], -1, ) - return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask) + return MultiOneHot([n, n + 2], shape, device, mask=mask) def _discrete_spec(self, shape, device, n): mask = self._make_mask(torch.Size([*shape, n])).to(device) - return DiscreteTensorSpec(n, shape, device, mask=mask) + return Categorical(n, shape, device, mask=mask) def _mult_discrete_spec(self, shape, device, n): shape = torch.Size([*shape, 2]) @@ -3507,7 +3397,7 @@ def _mult_discrete_spec(self, shape, device, n): ], -1, ) - return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask) + return MultiCategorical([n, n + 2], shape, device, mask=mask) def test_equal(self, shape, device, spectype, rand_shape, n=5): shape = torch.Size(shape) @@ -3579,7 +3469,7 @@ def test_project(self, shape, device, spectype, rand_shape, n=5): class TestDynamicSpec: def test_all(self): - spec = UnboundedContinuousTensorSpec((-1, 1, 2)) + spec = Unbounded((-1, 1, 2)) unb = spec assert spec.shape == (-1, 1, 2) x = torch.randn(3, 1, 2) @@ -3593,14 +3483,14 @@ def test_all(self): xunbd = x assert spec.is_in(x) - spec = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1) + spec = Bounded(shape=(-1, 1, 2), low=-1, high=1) bound = spec assert spec.shape == (-1, 1, 2) x = torch.rand((3, 1, 2)) xbound = x assert spec.is_in(x) - spec = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4) + spec = OneHot(shape=(-1, 1, 2, 4), n=4) oneh = spec assert spec.shape == (-1, 1, 2, 4) x = torch.zeros((3, 1, 2, 4), dtype=torch.bool) @@ -3608,14 +3498,14 @@ def test_all(self): xoneh = x assert spec.is_in(x) - spec = DiscreteTensorSpec(shape=(-1, 1, 2), n=4) + spec = Categorical(shape=(-1, 1, 2), n=4) disc = spec assert spec.shape == (-1, 1, 2) x = torch.randint(4, (3, 1, 2)) xdisc = x assert spec.is_in(x) - spec = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4]) + spec = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4]) moneh = spec assert spec.shape == (-1, 1, 2, 7) x = torch.zeros((3, 1, 2, 7), dtype=torch.bool) @@ -3624,7 +3514,7 @@ def test_all(self): xmoneh = x assert spec.is_in(x) - spec = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4]) + spec = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4]) mdisc = spec assert spec.mask is None assert spec.shape == (-1, 1, 2, 2) @@ -3632,7 +3522,7 @@ def test_all(self): xmdisc = x assert spec.is_in(x) - spec = CompositeSpec( + spec = Composite( unb=unb, unbd=unbd, bound=bound, @@ -3659,15 +3549,15 @@ def test_all(self): assert spec.is_in(data) def test_expand(self): - unb = UnboundedContinuousTensorSpec((-1, 1, 2)) + unb = Unbounded((-1, 1, 2)) unbd = UnboundedDiscreteTensorSpec((-1, 1, 2)) - bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1) - oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4) - disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4) - moneh = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4]) - mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4]) + bound = Bounded(shape=(-1, 1, 2), low=-1, high=1) + oneh = OneHot(shape=(-1, 1, 2, 4), n=4) + disc = Categorical(shape=(-1, 1, 2), n=4) + moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4]) + mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4]) - spec = CompositeSpec( + spec = Composite( unb=unb, unbd=unbd, bound=bound, @@ -3689,7 +3579,7 @@ def test_expand(self): class TestNonTensorSpec: def test_sample(self): - nts = NonTensorSpec(shape=(3, 4)) + nts = NonTensor(shape=(3, 4)) assert nts.one((2,)).shape == (2, 3, 4) assert nts.rand((2,)).shape == (2, 3, 4) assert nts.zero((2,)).shape == (2, 3, 4) @@ -3707,26 +3597,24 @@ def test_device_ordinal(): assert _make_ordinal_device(device) is None device = torch.device("cuda") - unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device) + unb = Unbounded((-1, 1, 2), device=device) assert unb.device == torch.device("cuda:0") unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device) assert unbd.device == torch.device("cuda:0") - bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device) + bound = Bounded(shape=(-1, 1, 2), low=-1, high=1, device=device) assert bound.device == torch.device("cuda:0") - oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device) + oneh = OneHot(shape=(-1, 1, 2, 4), n=4, device=device) assert oneh.device == torch.device("cuda:0") - disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device) + disc = Categorical(shape=(-1, 1, 2), n=4, device=device) assert disc.device == torch.device("cuda:0") - moneh = MultiOneHotDiscreteTensorSpec( - shape=(-1, 1, 2, 7), nvec=[3, 4], device=device - ) + moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4], device=device) assert moneh.device == torch.device("cuda:0") - mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device) + mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device) assert mdisc.device == torch.device("cuda:0") - mdisc = NonTensorSpec(shape=(-1, 1, 2, 2), device=device) + mdisc = NonTensor(shape=(-1, 1, 2, 2), device=device) assert mdisc.device == torch.device("cuda:0") - spec = CompositeSpec( + spec = Composite( unb=unb, unbd=unbd, bound=bound, @@ -3740,6 +3628,181 @@ def test_device_ordinal(): assert spec.device == torch.device("cuda:0") +class TestLegacy: + def test_one_hot(self): + with pytest.warns( + DeprecationWarning, + match="The OneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use OneHot instead.", + ): + one_hot = OneHotDiscreteTensorSpec(n=4) + assert isinstance(one_hot, OneHotDiscreteTensorSpec) + assert isinstance(one_hot, OneHot) + assert not isinstance(one_hot, Categorical) + one_hot = OneHot(n=4) + assert isinstance(one_hot, OneHotDiscreteTensorSpec) + assert isinstance(one_hot, OneHot) + assert not isinstance(one_hot, Categorical) + + def test_discrete(self): + with pytest.warns( + DeprecationWarning, + match="The DiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Categorical instead.", + ): + discrete = DiscreteTensorSpec(n=4) + assert isinstance(discrete, DiscreteTensorSpec) + assert isinstance(discrete, Categorical) + assert not isinstance(discrete, OneHot) + discrete = Categorical(n=4) + assert isinstance(discrete, DiscreteTensorSpec) + assert isinstance(discrete, Categorical) + assert not isinstance(discrete, OneHot) + + def test_unbounded(self): + + unbounded_continuous_impl = Unbounded(dtype=torch.float) + assert isinstance(unbounded_continuous_impl, Unbounded) + assert isinstance(unbounded_continuous_impl, UnboundedContinuous) + assert isinstance(unbounded_continuous_impl, UnboundedContinuousTensorSpec) + assert not isinstance(unbounded_continuous_impl, UnboundedDiscreteTensorSpec) + + unbounded_discrete_impl = Unbounded(dtype=torch.int) + assert isinstance(unbounded_discrete_impl, Unbounded) + assert isinstance(unbounded_discrete_impl, UnboundedDiscrete) + assert isinstance(unbounded_discrete_impl, UnboundedDiscreteTensorSpec) + assert not isinstance(unbounded_discrete_impl, UnboundedContinuousTensorSpec) + + with pytest.warns( + DeprecationWarning, + match="The UnboundedContinuousTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.", + ): + unbounded_continuous = UnboundedContinuousTensorSpec() + assert isinstance(unbounded_continuous, Unbounded) + assert isinstance(unbounded_continuous, UnboundedContinuous) + assert isinstance(unbounded_continuous, UnboundedContinuousTensorSpec) + assert not isinstance(unbounded_continuous, UnboundedDiscreteTensorSpec) + + with warnings.catch_warnings(): + unbounded_continuous = UnboundedContinuous() + + with pytest.warns( + DeprecationWarning, + match="The UnboundedDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.", + ): + unbounded_discrete = UnboundedDiscreteTensorSpec() + assert isinstance(unbounded_discrete, Unbounded) + assert isinstance(unbounded_discrete, UnboundedDiscrete) + assert isinstance(unbounded_discrete, UnboundedDiscreteTensorSpec) + assert not isinstance(unbounded_discrete, UnboundedContinuousTensorSpec) + + with warnings.catch_warnings(): + unbounded_discrete = UnboundedDiscrete() + + # What if we mess with dtypes? + with pytest.warns(DeprecationWarning): + unbounded_continuous_fake = UnboundedContinuousTensorSpec(dtype=torch.int32) + assert isinstance(unbounded_continuous_fake, Unbounded) + assert not isinstance(unbounded_continuous_fake, UnboundedContinuous) + assert not isinstance(unbounded_continuous_fake, UnboundedContinuousTensorSpec) + assert isinstance(unbounded_continuous_fake, UnboundedDiscrete) + assert isinstance(unbounded_continuous_fake, UnboundedDiscreteTensorSpec) + + with pytest.warns(DeprecationWarning): + unbounded_discrete_fake = UnboundedDiscreteTensorSpec(dtype=torch.float32) + assert isinstance(unbounded_discrete_fake, Unbounded) + assert isinstance(unbounded_discrete_fake, UnboundedContinuous) + assert isinstance(unbounded_discrete_fake, UnboundedContinuousTensorSpec) + assert not isinstance(unbounded_discrete_fake, UnboundedDiscrete) + assert not isinstance(unbounded_discrete_fake, UnboundedDiscreteTensorSpec) + + def test_multi_one_hot(self): + with pytest.warns( + DeprecationWarning, + match="The MultiOneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiOneHot instead.", + ): + one_hot = MultiOneHotDiscreteTensorSpec(nvec=[4, 3]) + assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec) + assert isinstance(one_hot, MultiOneHot) + assert not isinstance(one_hot, MultiCategorical) + one_hot = MultiOneHot(nvec=[4, 3]) + assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec) + assert isinstance(one_hot, MultiOneHot) + assert not isinstance(one_hot, MultiCategorical) + + def test_multi_categorical(self): + with pytest.warns( + DeprecationWarning, + match="The MultiDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiCategorical instead.", + ): + categorical = MultiDiscreteTensorSpec(nvec=[4, 3]) + assert isinstance(categorical, MultiDiscreteTensorSpec) + assert isinstance(categorical, MultiCategorical) + assert not isinstance(categorical, MultiOneHot) + categorical = MultiCategorical(nvec=[4, 3]) + assert isinstance(categorical, MultiDiscreteTensorSpec) + assert isinstance(categorical, MultiCategorical) + assert not isinstance(categorical, MultiOneHot) + + def test_binary(self): + with pytest.warns( + DeprecationWarning, + match="The BinaryDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Binary instead.", + ): + binary = BinaryDiscreteTensorSpec(5) + assert isinstance(binary, BinaryDiscreteTensorSpec) + assert isinstance(binary, Binary) + assert not isinstance(binary, MultiOneHot) + binary = Binary(5) + assert isinstance(binary, BinaryDiscreteTensorSpec) + assert isinstance(binary, Binary) + assert not isinstance(binary, MultiOneHot) + + def test_bounded(self): + with pytest.warns( + DeprecationWarning, + match="The BoundedTensorSpec has been deprecated and will be removed in v0.7. Please use Bounded instead.", + ): + bounded = BoundedTensorSpec(-2, 2, shape=()) + assert isinstance(bounded, BoundedTensorSpec) + assert isinstance(bounded, Bounded) + assert not isinstance(bounded, MultiOneHot) + bounded = Bounded(-2, 2, shape=()) + assert isinstance(bounded, BoundedTensorSpec) + assert isinstance(bounded, Bounded) + assert not isinstance(bounded, MultiOneHot) + + def test_composite(self): + with ( + pytest.warns( + DeprecationWarning, + match="The CompositeSpec has been deprecated and will be removed in v0.7. Please use Composite instead.", + ) + ): + composite = CompositeSpec() + assert isinstance(composite, CompositeSpec) + assert isinstance(composite, Composite) + assert not isinstance(composite, MultiOneHot) + composite = Composite() + assert isinstance(composite, CompositeSpec) + assert isinstance(composite, Composite) + assert not isinstance(composite, MultiOneHot) + + def test_non_tensor(self): + with ( + pytest.warns( + DeprecationWarning, + match="The NonTensorSpec has been deprecated and will be removed in v0.7. Please use NonTensor instead.", + ) + ): + non_tensor = NonTensorSpec() + assert isinstance(non_tensor, NonTensorSpec) + assert isinstance(non_tensor, NonTensor) + assert not isinstance(non_tensor, MultiOneHot) + non_tensor = NonTensor() + assert isinstance(non_tensor, NonTensorSpec) + assert isinstance(non_tensor, NonTensor) + assert not isinstance(non_tensor, MultiOneHot) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 42e0880e6a4..ea177cb9f96 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -11,11 +11,7 @@ from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs import ( CatFrames, Compose, @@ -119,8 +115,8 @@ def forward(self, x): return self.linear_1(x), self.linear_2(x) spec_dict = { - "_": UnboundedContinuousTensorSpec((4,)), - "out_2": UnboundedContinuousTensorSpec((3,)), + "_": Unbounded((4,)), + "out_2": Unbounded((3,)), } # warning due to "_" in spec keys @@ -129,7 +125,7 @@ def forward(self, x): MultiHeadLinear(5, 4, 3), in_keys=["input"], out_keys=["_", "out_2"], - spec=CompositeSpec(**spec_dict), + spec=Composite(**spec_dict), ) @pytest.mark.parametrize("safe", [True, False]) @@ -146,9 +142,9 @@ def test_stateful(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) + spec = Bounded(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) + spec = Unbounded(4) if safe and spec is None: with pytest.raises( @@ -210,9 +206,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) if spec_type is None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) + spec = Bounded(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) + spec = Unbounded(4) else: raise NotImplementedError @@ -291,9 +287,9 @@ def test_stateful(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) + spec = Bounded(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) + spec = Unbounded(4) kwargs = {} @@ -368,9 +364,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) + spec = Bounded(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) + spec = Unbounded(4) else: raise NotImplementedError @@ -481,7 +477,7 @@ def test_sequential_partial(self, stack): net3 = nn.Sequential(net3, NormalParamExtractor()) net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - spec = BoundedTensorSpec(-0.1, 0.1, 4) + spec = Bounded(-0.1, 0.1, 4) kwargs = {"distribution_class": TanhNormal} @@ -1340,7 +1336,7 @@ def call(data, params): def test_safe_specs(): out_key = ("a", "b") - spec = CompositeSpec(CompositeSpec({out_key: UnboundedContinuousTensorSpec()})) + spec = Composite(Composite({out_key: Unbounded()})) original_spec = spec.clone() mod = SafeModule( module=nn.Linear(3, 1), @@ -1354,9 +1350,7 @@ def test_safe_specs(): def test_actor_critic_specs(): action_key = ("agents", "action") - spec = CompositeSpec( - CompositeSpec({action_key: UnboundedContinuousTensorSpec(shape=(3,))}) - ) + spec = Composite(Composite({action_key: Unbounded(shape=(3,))})) policy_module = TensorDictModule( nn.Linear(3, 1), in_keys=[("agents", "observation")], diff --git a/test/test_transforms.py b/test/test_transforms.py index c38908eba1d..60968ad0975 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -15,7 +15,6 @@ from functools import partial from sys import platform -import numpy as np import pytest import tensordict.tensordict @@ -51,15 +50,15 @@ from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, + Bounded, + Categorical, + Composite, LazyTensorStorage, ReplayBuffer, TensorDictReplayBuffer, TensorSpec, TensorStorage, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs import ( ActionMask, @@ -934,21 +933,17 @@ def test_catframes_transform_observation_spec(self): ) mins = [0, 0.5] maxes = [0.5, 1] - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec( - space_min, space_max, (1, 3, 3), dtype=torch.double - ) + key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double) for key, space_min, space_max in zip(keys, mins, maxes) } ) result = cat_frames.transform_observation_spec(observation_spec) - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec( - space_min, space_max, (1, 3, 3), dtype=torch.double - ) + key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double) for key, space_min, space_max in zip(keys, mins, maxes) } ) @@ -1502,15 +1497,12 @@ def test_r3mnet_transform_observation_spec( ): r3m_net = _R3MNet(in_keys, out_keys, model, del_keys) - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: - exp_ts = CompositeSpec( - { - key: UnboundedContinuousTensorSpec(r3m_net.outdim, device) - for key in out_keys - } + exp_ts = Composite( + {key: Unbounded(r3m_net.outdim, device) for key in out_keys} ) observation_spec_out = r3m_net.transform_observation_spec(observation_spec) @@ -1526,8 +1518,8 @@ def test_r3mnet_transform_observation_spec( for key in in_keys: ts_dict[key] = observation_spec[key] for key in out_keys: - ts_dict[key] = UnboundedContinuousTensorSpec(r3m_net.outdim, device) - exp_ts = CompositeSpec(ts_dict) + ts_dict[key] = Unbounded(r3m_net.outdim, device) + exp_ts = Composite(ts_dict) observation_spec_out = r3m_net.transform_observation_spec(observation_spec) @@ -2020,12 +2012,12 @@ def test_transform_no_env(self, keys, device, out_key): assert tdc.get("dont touch").shape == dont_touch.shape if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 1, (1, 4, 32)) + observation_spec = Bounded(0, 1, (1, 4, 32)) observation_spec = cattensors.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys} + observation_spec = Composite( + {key: Bounded(0, 1, (1, 4, 32)) for key in keys} ) observation_spec = cattensors.transform_observation_spec(observation_spec) assert observation_spec[out_key].shape == torch.Size([1, len(keys) * 4, 32]) @@ -2166,12 +2158,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = crop.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = crop.transform_observation_spec(observation_spec) for key in keys: @@ -2373,12 +2365,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = cc.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = cc.transform_observation_spec(observation_spec) for key in keys: @@ -2722,18 +2714,15 @@ def test_double2float(self, keys, keys_inv, device): assert td.get("dont touch").dtype != torch.double if len(keys_total) == 1 and len(keys_inv) and keys[0] == "action": - action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) - input_spec = CompositeSpec( - full_action_spec=CompositeSpec(action=action_spec), full_state_spec=None + action_spec = Bounded(0, 1, (1, 3, 3), dtype=torch.double) + input_spec = Composite( + full_action_spec=Composite(action=action_spec), full_state_spec=None ) action_spec = double2float.transform_input_spec(input_spec) assert action_spec.dtype == torch.float else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 1, (1, 3, 3), dtype=torch.double) for key in keys} ) observation_spec = double2float.transform_observation_spec(observation_spec) for key in keys: @@ -2950,13 +2939,13 @@ class TestExcludeTransform(TransformBase): class EnvWithManyKeys(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - a=UnboundedContinuousTensorSpec(3), - b=UnboundedContinuousTensorSpec(3), - c=UnboundedContinuousTensorSpec(3), + self.observation_spec = Composite( + a=Unbounded(3), + b=Unbounded(3), + c=Unbounded(3), ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(2) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(2) def _step( self, @@ -3188,13 +3177,13 @@ class TestSelectTransform(TransformBase): class EnvWithManyKeys(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - a=UnboundedContinuousTensorSpec(3), - b=UnboundedContinuousTensorSpec(3), - c=UnboundedContinuousTensorSpec(3), + self.observation_spec = Composite( + a=Unbounded(3), + b=Unbounded(3), + c=Unbounded(3), ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(2) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(2) def _step( self, @@ -3513,15 +3502,12 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) observation_spec = flatten.transform_observation_spec(observation_spec) assert observation_spec.shape[-3] == expected_size else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) - for key in keys - } + observation_spec = Composite( + {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) observation_spec = flatten.transform_observation_spec(observation_spec) for key in keys: @@ -3556,15 +3542,12 @@ def test_transform_compose(self, keys, size, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) observation_spec = flatten.transform_observation_spec(observation_spec) assert observation_spec.shape[-3] == expected_size else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) - for key in keys - } + observation_spec = Composite( + {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) observation_spec = flatten.transform_observation_spec(observation_spec) for key in keys: @@ -3801,12 +3784,12 @@ def test_transform_no_env(self, keys, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = gs.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([1, 16, 16]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = gs.transform_observation_spec(observation_spec) for key in keys: @@ -3838,12 +3821,12 @@ def test_transform_compose(self, keys, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = gs.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([1, 16, 16]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = gs.transform_observation_spec(observation_spec) for key in keys: @@ -4443,9 +4426,7 @@ def test_observationnorm( assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec( - 0, 1, (nchannels, 16, 16), device=device - ) + observation_spec = Bounded(0, 1, (nchannels, 16, 16), device=device) observation_spec = on.transform_observation_spec(observation_spec) if standard_normal: assert (observation_spec.space.low == -loc / scale).all() @@ -4455,11 +4436,8 @@ def test_observationnorm( assert (observation_spec.space.high == scale + loc).all() else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 1, (nchannels, 16, 16), device=device) for key in keys} ) observation_spec = on.transform_observation_spec(observation_spec) for key in keys: @@ -4480,15 +4458,11 @@ def test_observationnorm_init_stats( ): def make_env(): base_env = ContinuousActionVecMockEnv( - observation_spec=CompositeSpec( - observation=BoundedTensorSpec( - low=1, high=1, shape=torch.Size([size]) - ), - observation_orig=BoundedTensorSpec( - low=1, high=1, shape=torch.Size([size]) - ), + observation_spec=Composite( + observation=Bounded(low=1, high=1, shape=torch.Size([size])), + observation_orig=Bounded(low=1, high=1, shape=torch.Size([size])), ), - action_spec=BoundedTensorSpec(low=1, high=1, shape=torch.Size((size,))), + action_spec=Bounded(low=1, high=1, shape=torch.Size((size,))), seed=0, ) base_env.out_key = "observation" @@ -4669,12 +4643,12 @@ def test_transform_no_env(self, interpolation, keys, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = resize.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = resize.transform_observation_spec(observation_spec) for key in keys: @@ -4706,12 +4680,12 @@ def test_transform_compose(self, interpolation, keys, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) observation_spec = resize.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = resize.transform_observation_spec(observation_spec) for key in keys: @@ -4947,7 +4921,7 @@ def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal): assert (td.get("dont touch") == td_copy.get("dont touch")).all() if len(keys_total) == 1: - reward_spec = UnboundedContinuousTensorSpec(device=device) + reward_spec = Unbounded(device=device) reward_spec = reward_scaling.transform_reward_spec(reward_spec) assert reward_spec.shape == torch.Size([1]) @@ -5341,24 +5315,24 @@ def test_sum_reward(self, keys, device): # test transform_observation_spec base_env = ContinuousActionVecMockEnv( - reward_spec=UnboundedContinuousTensorSpec(shape=(3, 16, 16)), + reward_spec=Unbounded(shape=(3, 16, 16)), ) transfomed_env = TransformedEnv(base_env, RewardSum()) transformed_observation_spec1 = transfomed_env.observation_spec - assert isinstance(transformed_observation_spec1, CompositeSpec) + assert isinstance(transformed_observation_spec1, Composite) assert "episode_reward" in transformed_observation_spec1.keys() assert "observation" in transformed_observation_spec1.keys() base_env = ContinuousActionVecMockEnv( - reward_spec=UnboundedContinuousTensorSpec(), - observation_spec=CompositeSpec( - observation=UnboundedContinuousTensorSpec(), - some_extra_observation=UnboundedContinuousTensorSpec(), + reward_spec=Unbounded(), + observation_spec=Composite( + observation=Unbounded(), + some_extra_observation=Unbounded(), ), ) transfomed_env = TransformedEnv(base_env, RewardSum()) transformed_observation_spec2 = transfomed_env.observation_spec - assert isinstance(transformed_observation_spec2, CompositeSpec) + assert isinstance(transformed_observation_spec2, Composite) assert "some_extra_observation" in transformed_observation_spec2.keys() assert "episode_reward" in transformed_observation_spec2.keys() @@ -5700,15 +5674,13 @@ def test_transform_no_env( assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec( - -1, 1, (*batch, *size, nchannels, 16, 16) - ) + observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) observation_spec = unsqueeze.transform_observation_spec(observation_spec) assert observation_spec.shape == expected_size else: - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16)) + key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) for key in keys } ) @@ -5862,15 +5834,13 @@ def test_transform_compose( assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec( - -1, 1, (*batch, *size, nchannels, 16, 16) - ) + observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) observation_spec = unsqueeze.transform_observation_spec(observation_spec) assert observation_spec.shape == expected_size else: - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16)) + key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) for key in keys } ) @@ -6466,7 +6436,7 @@ def test_transform_no_env(self, keys, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) + observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( observation_spec ) @@ -6474,11 +6444,8 @@ def test_transform_no_env(self, keys, batch, device): assert (observation_spec.space.low == 0).all() assert (observation_spec.space.high == 1).all() else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -6515,7 +6482,7 @@ def test_transform_compose(self, keys, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) + observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( observation_spec ) @@ -6523,11 +6490,8 @@ def test_transform_compose(self, keys, batch, device): assert (observation_spec.space.low == 0).all() assert (observation_spec.space.high == 1).all() else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -6670,7 +6634,7 @@ class TestTensorDictPrimer(TransformBase): def test_single_trans_env_check(self): env = TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6682,14 +6646,10 @@ def test_nested_key_env(self): env = TransformedEnv( env, TensorDictPrimer( - CompositeSpec( + Composite( { - "nested_1": CompositeSpec( - { - "mykey": UnboundedContinuousTensorSpec( - (env.nested_dim_1, 4) - ) - }, + "nested_1": Composite( + {"mykey": Unbounded((env.nested_dim_1, 4))}, shape=(env.nested_dim_1,), ) } @@ -6707,13 +6667,13 @@ def test_nested_key_env(self): assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True) def test_transform_no_env(self): - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])) + t = TensorDictPrimer(mykey=Unbounded([3])) td = TensorDict({"a": torch.zeros(())}, []) t(td) assert "mykey" in td.keys() def test_transform_model(self): - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])) + t = TensorDictPrimer(mykey=Unbounded([3])) model = nn.Sequential(t, nn.Identity()) td = TensorDict({}, []) model(td) @@ -6722,7 +6682,7 @@ def test_transform_model(self): @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, rbclass): batch_size = (2,) - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([*batch_size, 3])) + t = TensorDictPrimer(mykey=Unbounded([*batch_size, 3])) rb = rbclass(storage=LazyTensorStorage(10)) rb.append_transform(t) td = TensorDict({"a": torch.zeros(())}, []) @@ -6734,7 +6694,7 @@ def test_transform_inverse(self): raise pytest.skip("No inverse method for TensorDictPrimer") def test_transform_compose(self): - t = Compose(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))) + t = Compose(TensorDictPrimer(mykey=Unbounded([3]))) td = TensorDict({"a": torch.zeros(())}, []) t(td) assert "mykey" in td.keys() @@ -6743,7 +6703,7 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) env = maybe_fork_ParallelEnv(2, make_env) @@ -6761,7 +6721,7 @@ def test_serial_trans_env_check(self): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) env = SerialEnv(2, make_env) @@ -6778,7 +6738,7 @@ def make_env(): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=Unbounded([2, 4])), ) try: check_env_specs(env) @@ -6796,7 +6756,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def test_trans_serial_env_check(self, spec_shape): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), + TensorDictPrimer(mykey=Unbounded(spec_shape)), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6810,8 +6770,8 @@ def test_trans_serial_env_check(self, spec_shape): @pytest.mark.parametrize( "spec", [ - CompositeSpec(b=BoundedTensorSpec(-3, 3, [4])), - BoundedTensorSpec(-3, 3, [4]), + Composite(b=Bounded(-3, 3, [4])), + Bounded(-3, 3, [4]), ], ) @pytest.mark.parametrize("random", [True, False]) @@ -6861,9 +6821,7 @@ def make_env(): else: assert (tensordict_select == value).all() - if isinstance(spec, CompositeSpec) and any( - key != "action" for key in default_keys - ): + if isinstance(spec, Composite) and any(key != "action" for key in default_keys): for key in default_keys: if key in ("action",): continue @@ -6878,7 +6836,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=Unbounded([2, 4])), ) torch.manual_seed(0) env.set_seed(0) @@ -6888,7 +6846,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): 2, lambda: TransformedEnv( GymEnv(CARTPOLE_VERSIONED()), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), + TensorDictPrimer(mykey=Unbounded([4])), ), ) torch.manual_seed(0) @@ -6902,9 +6860,7 @@ def create_tensor(): env = TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer( - mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor - ), + TensorDictPrimer(mykey=Unbounded([3]), default_value=create_tensor), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6913,8 +6869,8 @@ def create_tensor(): def test_dict_default_value(self): # Test with a dict of float default values - key1_spec = UnboundedContinuousTensorSpec([3]) - key2_spec = UnboundedContinuousTensorSpec([3]) + key1_spec = Unbounded([3]) + key2_spec = Unbounded([3]) env = TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer( @@ -6937,8 +6893,8 @@ def test_dict_default_value(self): assert (rollout_td.get(("next", "mykey2")) == 2.0).all() # Test with a dict of callable default values - key1_spec = UnboundedContinuousTensorSpec([3]) - key2_spec = DiscreteTensorSpec(3, dtype=torch.int64) + key1_spec = Unbounded([3]) + key2_spec = Categorical(3, dtype=torch.int64) env = TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer( @@ -7751,13 +7707,11 @@ def test_vipnet_transform_observation_spec( ): vip_net = _VIPNet(in_keys, out_keys, model, del_keys) - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: - exp_ts = CompositeSpec( - {key: UnboundedContinuousTensorSpec(1024, device) for key in out_keys} - ) + exp_ts = Composite({key: Unbounded(1024, device) for key in out_keys}) observation_spec_out = vip_net.transform_observation_spec(observation_spec) @@ -7772,8 +7726,8 @@ def test_vipnet_transform_observation_spec( for key in in_keys: ts_dict[key] = observation_spec[key] for key in out_keys: - ts_dict[key] = UnboundedContinuousTensorSpec(1024, device) - exp_ts = CompositeSpec(ts_dict) + ts_dict[key] = Unbounded(1024, device) + exp_ts = Composite(ts_dict) observation_spec_out = vip_net.transform_observation_spec(observation_spec) @@ -8466,8 +8420,8 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: env.transform.transform_reward_spec(env.base_env.full_reward_spec) def test_independent_obs_specs_from_shared_env(self): - obs_spec = CompositeSpec( - observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) + obs_spec = Composite( + observation=Bounded(low=0, high=10, shape=torch.Size((1,))) ) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) t1 = TransformedEnv( @@ -8490,7 +8444,7 @@ def test_independent_obs_specs_from_shared_env(self): assert base_env.observation_spec["observation"].space.high == 10 def test_independent_reward_specs_from_shared_env(self): - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = Unbounded() base_env = ContinuousActionVecMockEnv(reward_spec=reward_spec) t1 = TransformedEnv( base_env, transform=RewardClipping(clamp_min=0, clamp_max=4) @@ -8508,8 +8462,14 @@ def test_independent_reward_specs_from_shared_env(self): assert t2_reward_spec.space.low == -2 assert t2_reward_spec.space.high == 2 - assert base_env.reward_spec.space.low == -np.inf - assert base_env.reward_spec.space.high == np.inf + assert ( + base_env.reward_spec.space.low + == torch.finfo(base_env.reward_spec.dtype).min + ) + assert ( + base_env.reward_spec.space.high + == torch.finfo(base_env.reward_spec.dtype).max + ) def test_allow_done_after_reset(self): base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) @@ -8637,13 +8597,13 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (nchannels, 16, 16)) + observation_spec = Bounded(0, 255, (nchannels, 16, 16)) # StepCounter does not want non composite specs observation_spec = compose[:2].transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(0, 255, (nchannels, 16, 16)) for key in keys} ) observation_spec = compose.transform_observation_spec(observation_spec) for key in keys: @@ -9600,9 +9560,7 @@ def _make_transform_env(self, out_key, base_env): return Compose( TensorDictPrimer( primers={ - "sample_log_prob": UnboundedContinuousTensorSpec( - shape=base_env.action_spec.shape[:-1] - ) + "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1]) } ), transform, @@ -9836,20 +9794,18 @@ def test_kl_lstm(self): class TestActionMask(TransformBase): @property def _env_class(self): - from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec + from torchrl.data import Binary, Categorical class MaskedEnv(EnvBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.action_spec = DiscreteTensorSpec(4) - self.state_spec = CompositeSpec( - action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool) - ) - self.observation_spec = CompositeSpec( - obs=UnboundedContinuousTensorSpec(3), - action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool), + self.action_spec = Categorical(4) + self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool)) + self.observation_spec = Composite( + obs=Unbounded(3), + action_mask=Binary(4, dtype=torch.bool), ) - self.reward_spec = UnboundedContinuousTensorSpec(1) + self.reward_spec = Unbounded(1) def _reset(self, tensordict): td = self.observation_spec.rand() @@ -10987,27 +10943,25 @@ class TestRemoveEmptySpecs(TransformBase): class DummyEnv(EnvBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - other=CompositeSpec( - another_other=CompositeSpec(shape=self.batch_size), + self.observation_spec = Composite( + observation=Unbounded((*self.batch_size, 3)), + other=Composite( + another_other=Composite(shape=self.batch_size), shape=self.batch_size, ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - self.done_spec = DiscreteTensorSpec( - 2, (*self.batch_size, 1), dtype=torch.bool - ) + self.action_spec = Unbounded((*self.batch_size, 3)) + self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone() - self.reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(*self.batch_size, 1), - other_reward=CompositeSpec(shape=self.batch_size), + self.reward_spec = Composite( + reward=Unbounded(*self.batch_size, 1), + other_reward=Composite(shape=self.batch_size), shape=self.batch_size, ) - self.state_spec = CompositeSpec( - state=CompositeSpec( - sub=CompositeSpec(shape=self.batch_size), shape=self.batch_size + self.state_spec = Composite( + state=Composite( + sub=Composite(shape=self.batch_size), shape=self.batch_size ), shape=self.batch_size, ) @@ -11213,11 +11167,9 @@ class MyEnv(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(3) - ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(1) + self.observation_spec = Composite(observation=Unbounded(3)) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(1) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_batch_size = ( diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index aeee573f8dc..2dd6fcf6c93 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -53,10 +53,10 @@ class submitit_delayed_launcher: ... def main(): ... from torchrl.envs.utils import RandomPolicy from torchrl.envs.libs.gym import GymEnv - ... from torchrl.data import BoundedTensorSpec + ... from torchrl.data import BoundedContinuous ... collector = DistributedDataCollector( ... [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs, - ... policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + ... policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))), ... launcher="submitit_delayed", ... ) ... for data in collector: diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 3749e6e8cbc..0c1eab4011c 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -58,19 +58,32 @@ TokenizedDatasetLoader, ) from .tensor_specs import ( + Binary, BinaryDiscreteTensorSpec, + Bounded, BoundedTensorSpec, + Categorical, + Composite, CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, LazyStackedCompositeSpec, LazyStackedTensorSpec, + MultiCategorical, MultiDiscreteTensorSpec, + MultiOneHot, MultiOneHotDiscreteTensorSpec, + NonTensor, NonTensorSpec, + OneHot, OneHotDiscreteTensorSpec, + Stacked, + StackedComposite, TensorSpec, + Unbounded, + UnboundedContinuous, UnboundedContinuousTensorSpec, + UnboundedDiscrete, UnboundedDiscreteTensorSpec, ) from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 3cc8d7437c0..d6a49f17113 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -25,12 +25,7 @@ from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.storages import TensorStorage from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Categorical, Composite, Unbounded from torchrl.envs.utils import _classproperty _has_tqdm = importlib.util.find_spec("tqdm", None) is not None @@ -398,24 +393,22 @@ def _proc_spec(spec): if spec is None: return if spec["type"] == "Dict": - return CompositeSpec( + return Composite( {key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()} ) elif spec["type"] == "Box": if all(item == -float("inf") for item in spec["low"]) and all( item == float("inf") for item in spec["high"] ): - return UnboundedContinuousTensorSpec( - spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] - ) - return BoundedTensorSpec( + return Unbounded(spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]) + return Bounded( shape=spec["shape"], low=torch.as_tensor(spec["low"]), high=torch.as_tensor(spec["high"]), dtype=_DTYPE_DIR[spec["dtype"]], ) elif spec["type"] == "Discrete": - return DiscreteTensorSpec( + return Categorical( spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] ) else: diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c16dd5f9ec6..a81fa3891ad 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -21,6 +21,7 @@ Generic, List, Optional, + overload, Sequence, Tuple, TypeVar, @@ -74,7 +75,7 @@ _DEFAULT_SHAPE = torch.Size((1,)) -DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined." +DEVICE_ERR_MSG = "device of empty Composite is not defined." NOT_IMPLEMENTED_ERROR = NotImplementedError( "method is not currently implemented." " If you are interested in this feature please submit" @@ -199,7 +200,7 @@ def _shape_indexing( Shape of the resulting spec Examples: >>> idx = (2, ..., None) - >>> DiscreteTensorSpec(2, shape=(3, 4))[idx].shape + >>> Categorical(2, shape=(3, 4))[idx].shape torch.Size([4, 1]) >>> _shape_indexing([3, 4], idx) torch.Size([4, 1]) @@ -359,7 +360,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: def __repr__(self): return f"{self.__class__.__name__}()" - def clone(self) -> DiscreteBox: + def clone(self) -> CategoricalBox: return deepcopy(self) @@ -459,19 +460,25 @@ def __eq__(self, other): @dataclass(repr=False) -class DiscreteBox(Box): - """A box of discrete values.""" +class CategoricalBox(Box): + """A box of discrete, categorical values.""" n: int register = invertible_dict() - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteBox: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: return deepcopy(self) def __repr__(self): return f"{self.__class__.__name__}(n={self.n})" +class DiscreteBox(CategoricalBox): + """Deprecated version of :class:`CategoricalBox`.""" + + ... + + @dataclass(repr=False) class BoxList(Box): """A box of discrete values.""" @@ -494,7 +501,7 @@ def __len__(self): @staticmethod def from_nvec(nvec: torch.Tensor): if nvec.ndim == 0: - return DiscreteBox(nvec.item()) + return CategoricalBox(nvec.item()) else: return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) @@ -514,14 +521,30 @@ def __repr__(self): @dataclass(repr=False) class TensorSpec: - """Parent class of the tensor meta-data containers for observation, actions and rewards. + """Parent class of the tensor meta-data containers. + + TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class, + or sometimes to simulate simple behaviours by generating random data within a defined space. + + TensorSpecs are primarily used in environments to specify their input/output structure without needing to + execute the environment (or starting it). They can also be used to instantiate shared buffers to pass + data from worker to worker. + + TensorSpecs are dataclasses that always share the following fields: `shape`, `space, `dtype` and `device`. + + As such, TensorSpecs possess some common behavior with :class:`~torch.Tensor` and :class:`~tensordict.TensorDict`: + they can be reshaped, indexed, squeezed, unsqueezed, moved to another device etc. Args: - shape (torch.Size): size of the tensor - space (Box): Box instance describing what kind of values can be - expected - device (torch.device): device of the tensor - dtype (torch.dtype): dtype of the tensor + shape (torch.Size): size of the tensor. The shape includes the batch dimensions as well as the feature + dimension. A negative shape (``-1``) means that the dimension has a variable number of elements. + space (Box): Box instance describing what kind of values can be expected. + device (torch.device): device of the tensor. + dtype (torch.dtype): dtype of the tensor. + + .. note:: A spec can be constructed from a :class:`~tensordict.TensorDict` using the :func:`~torchrl.envs.utils.make_composite_from_td` + function. This function makes a low-assumption educated guess on the specs that may correspond to the input + tensordict and can help to build specs automatically without an in-depth knowledge of the `TensorSpec` API. """ @@ -546,21 +569,35 @@ def decorator(func): @property def device(self) -> torch.device: + """The device of the spec. + + Only :class:`Composite` specs can have a ``None`` device. All leaves must have a non-null device. + """ return self._device @device.setter def device(self, device: torch.device | None) -> None: self._device = _make_ordinal_device(device) - def clear_device_(self): - """A no-op for all leaf specs (which must have a device).""" + def clear_device_(self) -> T: + """A no-op for all leaf specs (which must have a device). + + For :class:`Composite` specs, this method will erase the device. + """ return self def encode( - self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False - ) -> torch.Tensor: + self, + val: np.ndarray | torch.Tensor | TensorDictBase, + *, + ignore_device: bool = False, + ) -> torch.Tensor | TensorDictBase: """Encodes a value given the specified spec, and return the corresponding tensor. + This method is to be used in environments that return a value (eg, a numpy array) that can be + easily mapped to the TorchRL required domain. + If the value is already a tensor, the spec will not change its value and return it as-is. + Args: val (np.ndarray or torch.Tensor): value to be encoded as tensor. @@ -616,8 +653,12 @@ def __setattr__(self, key, value): value = torch.Size(value) super().__setattr__(key, value) - def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: - """Returns the np.ndarray correspondent of an input tensor. + def to_numpy( + self, val: torch.Tensor | TensorDictBase, safe: bool = None + ) -> np.ndarray | dict: + """Returns the ``np.ndarray`` correspondent of an input tensor. + + This is intended to be the inverse operation of :meth:`.encode`. Args: val (torch.Tensor): tensor to be transformed_in to numpy. @@ -626,7 +667,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable. Returns: - a np.ndarray + a np.ndarray. """ if safe is None: @@ -636,19 +677,31 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: return val.detach().cpu().numpy() @property - def ndim(self): + def ndim(self) -> int: + """Number of dimensions of the spec shape. + + Shortcut for ``len(spec.shape)``. + + """ return self.ndimension() - def ndimension(self): + def ndimension(self) -> int: + """Number of dimensions of the spec shape. + + Shortcut for ``len(spec.shape)``. + + """ return len(self.shape) @property - def _safe_shape(self): + def _safe_shape(self) -> torch.Size: """Returns a shape where all heterogeneous values are replaced by one (to be expandable).""" return torch.Size([int(v) if v >= 0 else 1 for v in self.shape]) @abc.abstractmethod - def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: + def index( + self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: """Indexes the input tensor. Args: @@ -661,20 +714,25 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten """ ... + @overload + def expand(self, shape: torch.Size): + ... + @abc.abstractmethod - def expand(self, *shape): - """Returns a new Spec with the extended shape. + def expand(self, *shape: int) -> T: + """Returns a new Spec with the expanded shape. Args: - *shape (tuple or iterable of int): the new shape of the Spec. Must comply with the current shape: + *shape (tuple or iterable of int): the new shape of the Spec. + Must be broadcastable with the current shape: its length must be at least as long as the current shape length, - and its last values must be complient too; ie they can only differ + and its last values must be compliant too; ie they can only differ from it if the current dimension is a singleton. """ ... - def squeeze(self, dim: int | None = None): + def squeeze(self, dim: int | None = None) -> T: """Returns a new Spec with all the dimensions of size ``1`` removed. When ``dim`` is given, a squeeze operation is done only in that dimension. @@ -688,11 +746,18 @@ def squeeze(self, dim: int | None = None): return self return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - def unsqueeze(self, dim: int): + def unsqueeze(self, dim: int) -> T: + """Returns a new Spec with one more singleton dimension (at the position indicated by ``dim``). + + Args: + dim (int or None): the dimension to apply the unsqueeze operation to. + + """ shape = _unsqueezed_shape(self.shape, dim) return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - def make_neg_dim(self, dim): + def make_neg_dim(self, dim: int) -> T: + """Converts a specific dimension to ``-1``.""" if dim < 0: dim = self.ndim + dim if dim < 0 or dim > self.ndim - 1: @@ -701,8 +766,12 @@ def make_neg_dim(self, dim): [s if i != dim else -1 for i, s in enumerate(self.shape)] ) - def reshape(self, *shape): - """Reshapes a tensorspec. + @overload + def reshape(self, shape) -> T: + ... + + def reshape(self, *shape) -> T: + """Reshapes a ``TensorSpec``. Check :func:`~torch.reshape` for more information on this method. @@ -714,23 +783,23 @@ def reshape(self, *shape): view = reshape @abc.abstractmethod - def _reshape(self, shape): + def _reshape(self, shape: torch.Size) -> T: ... - def unflatten(self, dim, sizes): - """Unflattens a tensorspec. + def unflatten(self, dim: int, sizes: Tuple[int]) -> T: + """Unflattens a ``TensorSpec``. Check :func:`~torch.unflatten` for more information on this method. """ return self._unflatten(dim, sizes) - def _unflatten(self, dim, sizes): + def _unflatten(self, dim: int, sizes: Tuple[int]) -> T: shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape return self._reshape(shape) - def flatten(self, start_dim, end_dim): - """Flattens a tensorspec. + def flatten(self, start_dim: int, end_dim: int) -> T: + """Flattens a ``TensorSpec``. Check :func:`~torch.flatten` for more information on this method. @@ -742,31 +811,39 @@ def _flatten(self, start_dim, end_dim): return self._reshape(shape) @abc.abstractmethod - def _project(self, val: torch.Tensor) -> torch.Tensor: + def _project( + self, val: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: raise NotImplementedError(type(self)) @abc.abstractmethod - def is_in(self, val: torch.Tensor) -> bool: - """If the value :obj:`val` is in the box defined by the TensorSpec, returns True, otherwise False. + def is_in(self, val: torch.Tensor | TensorDictBase) -> bool: + """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``. + + More precisely, the ``is_in`` methods checks that the value ``val`` is within the limits defined by the ``space`` + attribute (the box), and that the ``dtype``, ``device``, ``shape`` potentially other metadata match those + of the spec. If any of these checks fails, the ``is_in`` method will return ``False``. Args: - val (torch.Tensor): value to be checked + val (torch.Tensor): value to be checked. Returns: - boolean indicating if values belongs to the TensorSpec box + boolean indicating if values belongs to the TensorSpec box. """ ... - def contains(self, item): - """Returns whether a sample is contained within the space defined by the TensorSpec. + def contains(self, item: torch.Tensor | TensorDictBase) -> bool: + """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``. See :meth:`~.is_in` for more information. """ return self.is_in(item) - def project(self, val: torch.Tensor) -> torch.Tensor: - """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic. + def project( + self, val: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: + """If the input tensor is not in the TensorSpec box, it maps it back to it given some defined heuristic. Args: val (torch.Tensor): tensor to be mapped to the box. @@ -794,10 +871,10 @@ def assert_is_in(self, value: torch.Tensor) -> None: ) def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None: - """Checks the input value dtype against the TensorSpec dtype and raises an exception if they don't match. + """Checks the input value ``dtype`` against the ``TensorSpec`` ``dtype`` and raises an exception if they don't match. Args: - value (torch.Tensor): tensor whose dtype has to be checked + value (torch.Tensor): tensor whose dtype has to be checked. key (str, optional): if the TensorSpec has keys, the value dtype will be checked against the spec pointed by the indicated key. @@ -810,8 +887,11 @@ def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None: ) @abc.abstractmethod - def rand(self, shape=None) -> torch.Tensor: - """Returns a random tensor in the space defined by the spec. The sampling will be uniform unless the box is unbounded. + def rand(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Returns a random tensor in the space defined by the spec. + + The sampling will be done uniformly over the space, unless the box is unbounded in which case normal values + will be drawn. Args: shape (torch.Size): shape of the random tensor @@ -820,19 +900,22 @@ def rand(self, shape=None) -> torch.Tensor: a random tensor sampled in the TensorSpec box. """ - raise NotImplementedError + ... - @property - def sample(self): + def sample(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """Returns a random tensor in the space defined by the spec. See :meth:`~.rand` for details. """ - return self.rand + return self.rand(shape=shape) - def zero(self, shape=None) -> torch.Tensor: + def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """Returns a zero-filled tensor in the box. + .. note:: Even though there is no guarantee that ``0`` belongs to the spec domain, + this method will not raise an exception when this condition is violated. + The primary use case of ``zero`` is to generate empty data buffers, not meaningful data. + Args: shape (torch.Size): shape of the zero-tensor @@ -846,21 +929,54 @@ def zero(self, shape=None) -> torch.Tensor: (*shape, *self._safe_shape), dtype=self.dtype, device=self.device ) + def zeros(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Proxy to :meth:`~.zero`.""" + return self.zero(shape=shape) + + def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Returns a one-filled tensor in the box. + + .. note:: Even though there is no guarantee that ``1`` belongs to the spec domain, + this method will not raise an exception when this condition is violated. + The primary use case of ``one`` is to generate empty data buffers, not meaningful data. + + Args: + shape (torch.Size): shape of the one-tensor + + Returns: + a one-filled tensor sampled in the TensorSpec box. + + """ + if self.dtype == torch.bool: + return ~self.zero(shape=shape) + return self.zero(shape) + 1 + + def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Proxy to :meth:`~.one`.""" + return self.one(shape=shape) + @abc.abstractmethod def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": - raise NotImplementedError + """Casts a TensorSpec to a device or a dtype. + + Returns the same spec if no change is made. + """ + ... def cpu(self): + """Casts the TensorSpec to 'cpu' device.""" return self.to("cpu") def cuda(self, device=None): + """Casts the TensorSpec to 'cuda' device.""" if device is None: return self.to("cuda") return self.to(f"cuda:{device}") @abc.abstractmethod def clone(self) -> "TensorSpec": - raise NotImplementedError + """Creates a copy of the TensorSpec.""" + ... def __repr__(self): shape_str = indent("shape=" + str(self.shape), " " * 4) @@ -907,7 +1023,7 @@ def __init__(self, *specs: tuple[T, ...], dim: int) -> None: self.dim = len(self.shape) + self.dim def clear_device_(self): - """Clears the device of the CompositeSpec.""" + """Clears the device of the Composite.""" for spec in self._specs: spec.clear_device_() return self @@ -1006,7 +1122,7 @@ def clone(self) -> T: def stack_dim(self): return self.dim - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1017,7 +1133,7 @@ def zero(self, shape=None) -> TensorDictBase: ) return torch.nested.nested_tensor([spec.zero(shape) for spec in self._specs]) - def one(self, shape=None) -> TensorDictBase: + def one(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1028,7 +1144,7 @@ def one(self, shape=None) -> TensorDictBase: ) return torch.nested.nested_tensor([spec.one(shape) for spec in self._specs]) - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1134,7 +1250,7 @@ def squeeze(self, dim: int = None): ) -class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): +class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec): """A lazy representation of a stack of tensor specs. Stacks tensor-specs together along one dimension. @@ -1143,13 +1259,13 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): Indexing is allowed but only along the stack dimension. - This class is aimed to be used in multi-task and multi-agent settings, where + This class aims at being used in multi-tasks and multi-agent settings, where heterogeneous specs may occur (same semantic but different shape). """ def __eq__(self, other): - if not isinstance(other, LazyStackedTensorSpec): + if not isinstance(other, Stacked): return False if self.device != other.device: raise RuntimeError((self, other)) @@ -1170,8 +1286,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: if safe: if val.shape[self.dim] != len(self._specs): raise ValueError( - "Size of LazyStackedTensorSpec and val differ along the stacking " - "dimension" + "Size of Stacked and val differ along the stacking " "dimension" ) for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): spec.assert_is_in(v) @@ -1183,7 +1298,7 @@ def __repr__(self): dtype_str = "dtype=" + str(self.dtype) domain_str = "domain=" + str(self._specs[0].domain) sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str]) - string = f"LazyStacked{self._specs[0].__class__.__name__}(\n {sub_string})" + string = f"Stacked{self._specs[0].__class__.__name__}(\n {sub_string})" return string @property @@ -1304,7 +1419,7 @@ def encode( @dataclass(repr=False) -class OneHotDiscreteTensorSpec(TensorSpec): +class OneHot(TensorSpec): """A unidimensional, one-hot discrete tensor spec. By default, TorchRL assumes that categorical variables are encoded as @@ -1325,10 +1440,10 @@ class OneHotDiscreteTensorSpec(TensorSpec): Args: n (int): number of possible outcomes. shape (torch.Size, optional): total shape of the sampled tensors. - If provided, the last dimension must match n. + If provided, the last dimension must match ``n``. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. - user_register (bool): experimental feature. If True, every integer + use_register (bool): experimental feature. If ``True``, every integer will be mapped onto a binary vector in the order in which they appear. This feature is designed for environment with no a-priori definition of the number of possible outcomes (e.g. @@ -1338,16 +1453,29 @@ class OneHotDiscreteTensorSpec(TensorSpec): mask (torch.Tensor or None): mask some of the possible outcomes when a sample is taken. See :meth:`~.update_mask` for more information. + Examples: + >>> from torchrl.data.tensor_specs import OneHot + >>> spec = OneHot(5, shape=(2, 5)) + >>> spec.rand() + tensor([[False, True, False, False, False], + [False, True, False, False, False]]) + >>> mask = torch.tensor([ + ... [False, False, False, False, True], + ... [False, False, False, False, True] + ... ]) + >>> spec.update_mask(mask) + >>> spec.rand() + tensor([[False, False, False, False, True], + [False, False, False, False, True]]) + """ shape: torch.Size - space: DiscreteBox + space: CategoricalBox device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, n: int, @@ -1359,7 +1487,7 @@ def __init__( ): dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register - space = DiscreteBox(n) + space = CategoricalBox(n) if shape is None: shape = torch.Size((space.n,)) else: @@ -1387,12 +1515,12 @@ def update_mask(self, mask): mask (torch.Tensor or None): boolean mask. If None, the mask is disabled. Otherwise, the shape of the mask must be expandable to the shape of the spec. ``False`` masks an outcome and ``True`` - leaves the outcome unmasked. If all of the possible outcomes are + leaves the outcome unmasked. If all the possible outcomes are masked, then an error is raised when a sample is taken. Examples: >>> mask = torch.tensor([True, False, False]) - >>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask) + >>> ts = OneHot(3, (2, 3,), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes are masked >>> ts.rand() tensor([[1, 0, 0], @@ -1407,7 +1535,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> OneHot: if dest is None: return self if isinstance(dest, torch.dtype): @@ -1427,7 +1555,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: mask=self.mask.to(dest) if self.mask is not None else None, ) - def clone(self) -> OneHotDiscreteTensorSpec: + def clone(self) -> OneHot: return self.__class__( n=self.space.n, shape=self.shape, @@ -1545,7 +1673,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: @@ -1570,7 +1698,7 @@ def rand(self, shape=None) -> torch.Tensor: def encode( self, val: Union[np.ndarray, torch.Tensor], - space: Optional[DiscreteBox] = None, + space: Optional[CategoricalBox] = None, *, ignore_device: bool = False, ) -> torch.Tensor: @@ -1698,6 +1826,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The categorical tensor. + + Examples: + >>> one_hot = OneHot(3, shape=(2, 3)) + >>> one_hot_sample = one_hot.rand() + >>> one_hot_sample + tensor([[False, True, False], + [False, True, False]]) + >>> categ_sample = one_hot.to_categorical(one_hot_sample) + >>> categ_sample + tensor([1, 1]) """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -1705,25 +1843,103 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: self.assert_is_in(val) return val.long().argmax(-1) - def to_categorical_spec(self) -> DiscreteTensorSpec: - """Converts the spec to the equivalent categorical spec.""" - return DiscreteTensorSpec( + def to_categorical_spec(self) -> Categorical: + """Converts the spec to the equivalent categorical spec. + + Examples: + >>> one_hot = OneHot(3, shape=(2, 3)) + >>> one_hot.to_categorical_spec() + Categorical( + shape=torch.Size([2]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + + """ + return Categorical( self.space.n, device=self.device, shape=self.shape[:-1], mask=self.mask, ) + def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for OneHot.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """No-op for OneHot.""" + return self + + +class _BoundedMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if instance.domain == "continuous": + instance.__class__ = BoundedContinuous + else: + instance.__class__ = BoundedDiscrete + return instance + @dataclass(repr=False) -class BoundedTensorSpec(TensorSpec): - """A bounded continuous tensor spec. +class Bounded(TensorSpec, metaclass=_BoundedMeta): + """A bounded tensor spec. + + ``Bounded`` specs will never appear as such and always be subclassed as :class:`BoundedContinuous` + or :class:`BoundedDiscrete` depending on their dtype (floating points dtypes will result in + :class:`BoundedContinuous` instances, all others in :class:`BoundedDiscrete` instances). Args: low (np.ndarray, torch.Tensor or number): lower bound of the box. high (np.ndarray, torch.Tensor or number): upper bound of the box. + shape (torch.Size): the shape of the ``Bounded`` spec. The shape must be specified. + Inputs ``low``, ``high`` and ``shape`` must be broadcastable. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment. + + Examples: + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.float) + >>> spec + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int) + >>> spec + BoundedDiscrete( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=discrete) + >>> spec.to(torch.float) + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int, domain="continuous") + >>> spec + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=continuous) """ @@ -1757,13 +1973,18 @@ def __init__( "Minimum is deprecated since v0.4.0, using low instead.", category=DeprecationWarning, ) - domain = kwargs.pop("domain", "continuous") + domain = kwargs.pop("domain", None) if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") dtype, device = _default_dtype_and_device(dtype, device) if dtype is None: dtype = torch.get_default_dtype() + if domain is None: + if dtype.is_floating_point: + domain = "continuous" + else: + domain = "discrete" if not isinstance(low, torch.Tensor): low = torch.tensor(low, dtype=dtype, device=device) @@ -1778,7 +1999,7 @@ def __init__( if dtype is not None and high.dtype is not dtype: high = high.to(dtype) err_msg = ( - "BoundedTensorSpec requires the shape to be explicitely (via " + "Bounded requires the shape to be explicitely (via " "the shape argument) or implicitely defined (via either the " "minimum or the maximum or both). If the maximum and/or the " "minimum have a non-singleton shape, they must match the " @@ -1954,7 +2175,7 @@ def unbind(self, dim: int = 0): for low, high in zip(low, high) ) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) a, b = self.space @@ -2037,7 +2258,7 @@ def is_in(self, val: torch.Tensor) -> bool: return False raise err - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2048,7 +2269,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__( + return Bounded( low=self.space.low.to(dest), high=self.space.high.to(dest), shape=self.shape, @@ -2056,7 +2277,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dtype=dest_dtype, ) - def clone(self) -> BoundedTensorSpec: + def clone(self) -> Bounded: return self.__class__( low=self.space.low.clone(), high=self.space.high.clone(), @@ -2083,6 +2304,45 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) +class BoundedContinuous(Bounded, metaclass=_BoundedMeta): + """A specialized version of :class:`torchrl.data.Bounded` with continuous space.""" + + def __init__( + self, + low: Union[float, torch.Tensor, np.ndarray] = None, + high: Union[float, torch.Tensor, np.ndarray] = None, + shape: Optional[Union[torch.Size, int]] = None, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + domain: str = "continuous", + ): + super().__init__( + low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain + ) + + +class BoundedDiscrete(Bounded, metaclass=_BoundedMeta): + """A specialized version of :class:`torchrl.data.Bounded` with discrete space.""" + + def __init__( + self, + low: Union[float, torch.Tensor, np.ndarray] = None, + high: Union[float, torch.Tensor, np.ndarray] = None, + shape: Optional[Union[torch.Size, int]] = None, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + domain: str = "discrete", + ): + super().__init__( + low=low, + high=high, + shape=shape, + device=device, + dtype=dtype, + domain=domain, + ) + + def _is_nested_list(index, notuple=False): if not notuple and isinstance(index, tuple): for idx in index: @@ -2097,8 +2357,14 @@ def _is_nested_list(index, notuple=False): return False -class NonTensorSpec(TensorSpec): - """A spec for non-tensor data.""" +class NonTensor(TensorSpec): + """A spec for non-tensor data. + + This spec has a shae, device and dtype like :class:`~tensordict.NonTensorData`. + + :meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with `None` data value. + (same will go for :meth:`.zero` and :meth:`.one`). + """ def __init__( self, @@ -2116,7 +2382,7 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2129,7 +2395,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: return self return self.__class__(shape=self.shape, device=dest_device, dtype=None) - def clone(self) -> NonTensorSpec: + def clone(self) -> NonTensor: return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) def rand(self, shape=None): @@ -2212,17 +2478,76 @@ def unbind(self, dim: int = 0): ) +class _UnboundedMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if instance.domain == "continuous": + instance.__class__ = UnboundedContinuous + else: + instance.__class__ = UnboundedDiscrete + return instance + + @dataclass(repr=False) -class UnboundedContinuousTensorSpec(TensorSpec): - """An unbounded continuous tensor spec. +class Unbounded(TensorSpec, metaclass=_UnboundedMeta): + """An unbounded tensor spec. + + ``Unbounded`` specs will never appear as such and always be subclassed as :class:`UnboundedContinuous` + or :class:`UnboundedDiscrete` depending on their dtype (floating points dtypes will result in + :class:`UnboundedContinuous` instances, all others in :class:`UnboundedDiscrete` instances). + + Although it is not properly limited above and below, this class still has a :attr:`Box` space that encodes + the maximum and minimum value that the dtype accepts. Args: + shape (torch.Size): the shape of the ``Bounded`` spec. The shape must be specified. + Inputs ``low``, ``high`` and ``shape`` must be broadcastable. device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors - (should be an floating point dtype such as float, double etc.) - """ + dtype (str or torch.dtype, optional): dtype of the tensors. + domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment. - # SPEC_HANDLED_FUNCTIONS = {} + Examples: + >>> spec = Unbounded(shape=(), dtype=torch.float) + >>> spec + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Unbounded(shape=(), dtype=torch.int) + >>> spec + UnboundedDiscrete( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=discrete) + >>> spec.to(torch.float) + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Unbounded(shape=(), dtype=torch.int, domain="continuous") + >>> spec + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=continuous) + + """ def __init__( self, @@ -2235,23 +2560,34 @@ def __init__( shape = torch.Size([shape]) dtype, device = _default_dtype_and_device(dtype, device) - box = ( - ContinuousBox( - torch.as_tensor(-np.inf, device=device).expand(shape), - torch.as_tensor(np.inf, device=device).expand(shape), - ) - if shape == _DEFAULT_SHAPE - else None + if dtype == torch.bool: + min_value = False + max_value = True + default_domain = "discrete" + else: + if dtype.is_floating_point: + min_value = torch.finfo(dtype).min + max_value = torch.finfo(dtype).max + default_domain = "continuous" + else: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max + default_domain = "discrete" + box = ContinuousBox( + torch.full( + _remove_neg_shapes(shape), min_value, device=device, dtype=dtype + ), + torch.full( + _remove_neg_shapes(shape), max_value, device=device, dtype=dtype + ), ) - default_domain = "continuous" if dtype.is_floating_point else "discrete" + domain = kwargs.pop("domain", default_domain) super().__init__( shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs ) - def to( - self, dest: Union[torch.dtype, DEVICE_TYPING] - ) -> UnboundedContinuousTensorSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2262,12 +2598,12 @@ def to( dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) + return Unbounded(shape=self.shape, device=dest_device, dtype=dest_dtype) - def clone(self) -> UnboundedContinuousTensorSpec: + def clone(self) -> Unbounded: return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) shape = [*shape, *self.shape] @@ -2333,21 +2669,12 @@ def unbind(self, dim: int = 0): def __eq__(self, other): # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedDiscreteTensorSpec): - return ( - UnboundedDiscreteTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - ) - == other - ) - if isinstance(other, BoundedTensorSpec): + if isinstance(other, Bounded): minval, maxval = _minmax_dtype(self.dtype) minval = torch.as_tensor(minval).to(self.device, self.dtype) maxval = torch.as_tensor(maxval).to(self.device, self.dtype) return ( - BoundedTensorSpec( + Bounded( shape=self.shape, high=maxval, low=minval, @@ -2357,185 +2684,43 @@ def __eq__(self, other): ) == other ) + elif isinstance(other, Unbounded): + if self.dtype != other.dtype: + return False + if self.shape != other.shape: + return False + if self.device != other.device: + return False + return True return super().__eq__(other) -@dataclass(repr=False) -class UnboundedDiscreteTensorSpec(TensorSpec): - """An unbounded discrete tensor spec. +class UnboundedContinuous(Unbounded): + """A specialized version of :class:`torchrl.data.Unbounded` with continuous space.""" - Args: - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors - (should be an integer dtype such as long, uint8 etc.) - """ + ... - # SPEC_HANDLED_FUNCTIONS = {} + +class UnboundedDiscrete(Unbounded): + """A specialized version of :class:`torchrl.data.Unbounded` with discrete space.""" def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.int64, + **kwargs, ): - if isinstance(shape, int): - shape = torch.Size([shape]) - - dtype, device = _default_dtype_and_device(dtype, device) - if dtype == torch.bool: - min_value = False - max_value = True - else: - if dtype.is_floating_point: - min_value = torch.finfo(dtype).min - max_value = torch.finfo(dtype).max - else: - min_value = torch.iinfo(dtype).min - max_value = torch.iinfo(dtype).max - space = ContinuousBox( - torch.full(_remove_neg_shapes(shape), min_value, device=device), - torch.full(_remove_neg_shapes(shape), max_value, device=device), - ) - - super().__init__( - shape=shape, - space=space, - device=device, - dtype=dtype, - domain="discrete", - ) - - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: - if isinstance(dest, torch.dtype): - dest_dtype = dest - dest_device = self.device - elif dest is None: - return self - else: - dest_dtype = self.dtype - dest_device = torch.device(dest) - if dest_device == self.device and dest_dtype == self.dtype: - return self - return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) - - def clone(self) -> UnboundedDiscreteTensorSpec: - return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) - - def rand(self, shape=None) -> torch.Tensor: - if shape is None: - shape = torch.Size([]) - interval = self.space.high - self.space.low - r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) - r = r * interval - r = self.space.low + r - r = r.to(self.dtype) - return r.to(self.device) - - def is_in(self, val: torch.Tensor) -> bool: - shape = torch.broadcast_shapes(self._safe_shape, val.shape) - return val.shape == shape and val.dtype == self.dtype - - def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): - shape = shape[0] - if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): - raise ValueError( - f"The last {self.ndim} of the expanded shape {shape} must match the" - f"shape of the {self.__class__.__name__} spec in expand()." - ) - return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - - def _reshape(self, shape): - return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - - def _unflatten(self, dim, sizes): - shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape - return self.__class__( - shape=shape, - device=self.device, - dtype=self.dtype, - ) - - def __getitem__(self, idx: SHAPE_INDEX_TYPING): - """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) - return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) - - def unbind(self, dim: int = 0): - orig_dim = dim - if dim < 0: - dim = len(self.shape) + dim - if dim < 0: - raise ValueError( - f"Cannot unbind along dim {orig_dim} with shape {self.shape}." - ) - shape = tuple(s for i, s in enumerate(self.shape) if i != dim) - return tuple( - self.__class__( - shape=shape, - device=self.device, - dtype=self.dtype, - ) - for i in range(self.shape[dim]) - ) - - def __eq__(self, other): - # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedContinuousTensorSpec): - return ( - UnboundedContinuousTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - domain=self.domain, - ) - == other - ) - if isinstance(other, BoundedTensorSpec): - return ( - BoundedTensorSpec( - shape=self.shape, - high=self.space.high, - low=self.space.low, - dtype=self.dtype, - device=self.device, - domain=self.domain, - ) - == other - ) - return super().__eq__(other) - - def __ne__(self, other): - # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedContinuousTensorSpec): - return ( - UnboundedContinuousTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - domain=self.domain, - ) - != other - ) - if isinstance(other, BoundedTensorSpec): - return ( - BoundedTensorSpec( - shape=self.shape, - high=self.space.high, - low=self.space.low, - dtype=self.dtype, - device=self.device, - domain=self.domain, - ) - != other - ) - return super().__ne__(other) + super().__init__(shape=shape, device=device, dtype=dtype, **kwargs) @dataclass(repr=False) -class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): +class MultiOneHot(OneHot): """A concatenation of one-hot discrete tensor spec. + This class can be used when a single tensor must carry information about multiple one-hot encoded + values. + The last dimension of the shape (domain of the tensor elements) cannot be indexed. Args: @@ -2550,20 +2735,22 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): sample is taken. See :meth:`~.update_mask` for more information. Examples: - >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) - >>> ts.is_in(torch.tensor([0,0,1, - ... 0,1, - ... 1,0,0])) + >>> ts = MultiOneHot((3,2,3)) + >>> ts.rand() + tensor([ True, False, False, True, False, False, False, True]) + >>> ts.is_in(torch.tensor([ + ... 0, 0, 1, + ... 0, 1, + ... 1, 0, 0], dtype=torch.bool)) True - >>> ts.is_in(torch.tensor([1,0,1, - ... 0,1, - ... 1,0,0])) # False + >>> ts.is_in(torch.tensor([ + ... 1, 0, 1, + ... 0, 1, + ... 1, 0, 0], dtype=torch.bool)) False """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, nvec: Sequence[int], @@ -2584,9 +2771,9 @@ def __init__( f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. " f"Got sum(nvec)={sum(nvec)} and shape={shape}." ) - space = BoxList([DiscreteBox(n) for n in nvec]) + space = BoxList([CategoricalBox(n) for n in nvec]) self.use_register = use_register - super(OneHotDiscreteTensorSpec, self).__init__( + super(OneHot, self).__init__( shape, space, device, @@ -2610,7 +2797,7 @@ def update_mask(self, mask): Examples: >>> mask = torch.tensor([True, False, False, ... True, True]) - >>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask) + >>> ts = MultiOneHot((3, 2), (2, 5), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes for the first >>> # one-hot group are masked, but neither of the two possible >>> # outcomes for the second one-hot group are masked. @@ -2627,7 +2814,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiOneHot: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2638,7 +2825,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__( + return MultiOneHot( nvec=deepcopy(self.nvec), shape=self.shape, device=dest_device, @@ -2646,7 +2833,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: mask=self.mask.to(dest) if self.mask is not None else None, ) - def clone(self) -> MultiOneHotDiscreteTensorSpec: + def clone(self) -> MultiOneHot: return self.__class__( nvec=deepcopy(self.nvec), shape=self.shape, @@ -2731,9 +2918,7 @@ def encode( f"value {v} is greater than the allowed max {space.n}" ) x.append( - super(MultiOneHotDiscreteTensorSpec, self).encode( - v, space, ignore_device=ignore_device - ) + super(MultiOneHot, self).encode(v, space, ignore_device=ignore_device) ) return torch.cat(x, -1).reshape(self.shape) @@ -2785,7 +2970,7 @@ def _split_self(self): n = space.n shape = self.shape[:-1] + (n,) result.append( - OneHotDiscreteTensorSpec( + OneHot( n=n, shape=shape, device=device, @@ -2807,6 +2992,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The categorical tensor. + + Examples: + >>> mone_hot = MultiOneHot((2, 3, 4)) + >>> onehot_sample = mone_hot.rand() + >>> onehot_sample + tensor([False, True, False, False, True, False, True, False, False]) + >>> categ_sample = mone_hot.to_categorical(onehot_sample) + >>> categ_sample + tensor([1, 2, 1]) + """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -2815,15 +3010,36 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: vals = self._split(val) return torch.stack([val.long().argmax(-1) for val in vals], -1) - def to_categorical_spec(self) -> MultiDiscreteTensorSpec: - """Converts the spec to the equivalent categorical spec.""" - return MultiDiscreteTensorSpec( + def to_categorical_spec(self) -> MultiCategorical: + """Converts the spec to the equivalent categorical spec. + + Examples: + >>> mone_hot = MultiOneHot((2, 3, 4)) + >>> categ = mone_hot.to_categorical_spec() + >>> categ + MultiCategorical( + shape=torch.Size([3]), + space=BoxList(boxes=[CategoricalBox(n=2), CategoricalBox(n=3), CategoricalBox(n=4)]), + device=cpu, + dtype=torch.int64, + domain=discrete) + + """ + return MultiCategorical( [_space.n for _space in self.space], device=self.device, shape=[*self.shape[:-1], len(self.space)], mask=self.mask, ) + def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for MultiOneHot.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """No-op for MultiOneHot.""" + return self + def expand(self, *shape): nvecs = [space.n for space in self.space] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): @@ -2932,22 +3148,15 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) -class DiscreteTensorSpec(TensorSpec): +class Categorical(TensorSpec): """A discrete tensor spec. - An alternative to OneHotTensorSpec for categorical variables in TorchRL. Instead of - using multiplication, categorical variables perform indexing which can speed up + An alternative to :class:`OneHot` for categorical variables in TorchRL. + Categorical variables perform indexing insted of masking, which can speed-up computation and reduce memory cost for large categorical variables. - The last dimension of the spec (length n of the binary vector) cannot be indexed - Example: - >>> batch, size = 3, 4 - >>> action_value = torch.arange(batch*size) - >>> action_value = action_value.view(batch, size).to(torch.float) - >>> action = torch.argmax(action_value, dim=-1).to(torch.long) - >>> chosen_action_value = action_value[range(batch), action] - >>> print(chosen_action_value) - tensor([ 3., 7., 11.]) + The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is + desired for the training dimension, one should specify it explicitly. Args: n (int): number of possible outcomes. @@ -2957,10 +3166,32 @@ class DiscreteTensorSpec(TensorSpec): mask (torch.Tensor or None): mask some of the possible outcomes when a sample is taken. See :meth:`~.update_mask` for more information. + Examples: + >>> categ = Categorical(3) + >>> categ + Categorical( + shape=torch.Size([]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + >>> categ.rand() + tensor(2) + >>> categ = Categorical(3, shape=(1,)) + >>> categ + Categorical( + shape=torch.Size([1]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + >>> categ.rand() + tensor([1]) + """ shape: torch.Size - space: DiscreteBox + space: CategoricalBox device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -2978,7 +3209,7 @@ def __init__( if shape is None: shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) - space = DiscreteBox(n) + space = CategoricalBox(n) super().__init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" ) @@ -3003,7 +3234,7 @@ def update_mask(self, mask): Examples: >>> mask = torch.tensor([True, False, True]) - >>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask) + >>> ts = Categorical(3, (10,), dtype=torch.int64, mask=mask) >>> # One of the three possible outcomes is masked >>> ts.rand() tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2]) @@ -3017,7 +3248,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) if self.mask is None: @@ -3115,6 +3346,15 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The one-hot encoded tensor. + + Examples: + >>> categ = Categorical(3) + >>> categ_sample = categ.zero() + >>> categ_sample + tensor(0) + >>> onehot_sample = categ.to_one_hot(categ_sample) + >>> onehot_sample + tensor([ True, False, False]) """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -3122,15 +3362,35 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: self.assert_is_in(val) return torch.nn.functional.one_hot(val, self.space.n).bool() - def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: - """Converts the spec to the equivalent one-hot spec.""" + def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for categorical.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """Converts the spec to the equivalent one-hot spec. + + Examples: + >>> categ = Categorical(3) + >>> categ.to_one_hot_spec() + OneHot( + shape=torch.Size([3]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.bool, + domain=discrete) + + """ shape = [*self.shape, self.space.n] - return OneHotDiscreteTensorSpec( + return OneHot( n=self.space.n, shape=shape, device=self.device, ) + def to_categorical_spec(self) -> Categorical: + """No-op for categorical.""" + return self + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -3208,7 +3468,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Categorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3223,7 +3483,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype ) - def clone(self) -> DiscreteTensorSpec: + def clone(self) -> Categorical: return self.__class__( n=self.space.n, shape=self.shape, @@ -3234,28 +3494,55 @@ def clone(self) -> DiscreteTensorSpec: @dataclass(repr=False) -class BinaryDiscreteTensorSpec(DiscreteTensorSpec): +class Binary(Categorical): """A binary discrete tensor spec. + A binary tensor spec encodes tensors of arbitrary size where the values are either 0 or 1 (or ``True`` or ``False`` + if the dtype it ``torch.bool``). + + Unlike :class:`OneHot`, `Binary` can have more than one non-null element along the last dimension. + Args: - n (int): length of the binary vector. + n (int): length of the binary vector. If provided along with ``shape``, ``shape[-1]`` must match ``n``. + If not provided, ``shape`` must be passed. + + .. warning:: the ``n`` argument from ``Binary`` must not be confused with the ``n`` argument from :class:`Categorical` + or :class:`OneHot` which denotes the maximum nmber of elements that can be sampled. + For clarity, use ``shape`` instead. + shape (torch.Size, optional): total shape of the sampled tensors. - If provided, the last dimension must match n. + If provided, the last dimension must match ``n``. device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long. + dtype (str or torch.dtype, optional): dtype of the tensors. + Defaults to ``torch.int8``. Examples: - >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool) - >>> print(spec.zero()) + >>> torch.manual_seed(0) + >>> spec = Binary(n=4, shape=(2, 4)) + >>> print(spec.rand()) + tensor([[0, 1, 1, 0], + [1, 1, 1, 1]], dtype=torch.int8) + >>> spec = Binary(shape=(2, 4)) + >>> print(spec.rand()) + tensor([[1, 1, 1, 0], + [0, 1, 0, 0]], dtype=torch.int8) + >>> spec = Binary(n=4) + >>> print(spec.rand()) + tensor([0, 0, 0, 1], dtype=torch.int8) + """ def __init__( self, - n: int, + n: int | None = None, shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Union[str, torch.dtype] = torch.int8, ): + if n is None and not shape: + raise TypeError("Must provide either n or shape.") + if n is None: + n = shape[-1] if shape is None or not len(shape): shape = torch.Size((n,)) else: @@ -3327,7 +3614,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Binary: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3342,7 +3629,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype ) - def clone(self) -> BinaryDiscreteTensorSpec: + def clone(self) -> Binary: return self.__class__( n=self.shape[-1], shape=self.shape, @@ -3364,8 +3651,8 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) def __eq__(self, other): - if not isinstance(other, BinaryDiscreteTensorSpec): - if isinstance(other, DiscreteTensorSpec): + if not isinstance(other, Binary): + if isinstance(other, Categorical): return ( other.n == 2 and other.device == self.device @@ -3377,7 +3664,7 @@ def __eq__(self, other): @dataclass(repr=False) -class MultiDiscreteTensorSpec(DiscreteTensorSpec): +class MultiCategorical(Categorical): """A concatenation of discrete tensor spec. Args: @@ -3394,15 +3681,13 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): sample is taken. See :meth:`~.update_mask` for more information. Examples: - >>> ts = MultiDiscreteTensorSpec((3, 2, 3)) + >>> ts = MultiCategorical((3, 2, 3)) >>> ts.is_in(torch.tensor([2, 0, 1])) True - >>> ts.is_in(torch.tensor([2, 2, 1])) + >>> ts.is_in(torch.tensor([2, 10, 1])) False """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, nvec: Union[Sequence[int], torch.Tensor, int], @@ -3431,7 +3716,7 @@ def __init__( self.nvec = self.nvec.expand(_remove_neg_shapes(shape)) space = BoxList.from_nvec(self.nvec) - super(DiscreteTensorSpec, self).__init__( + super(Categorical, self).__init__( shape, space, device, dtype, domain="discrete" ) self.update_mask(mask) @@ -3451,9 +3736,10 @@ def update_mask(self, mask): sample is taken. Examples: + >>> torch.manual_seed(0) >>> mask = torch.tensor([False, False, True, ... True, True]) - >>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask) + >>> ts = MultiCategorical((3, 2), (5, 2,), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes for the first >>> # group are masked, but neither of the two possible >>> # outcomes for the second group are masked. @@ -3462,7 +3748,7 @@ def update_mask(self, mask): [2, 0], [2, 1], [2, 1], - [2, 0]]) + [2, 1]]) """ if mask is not None: try: @@ -3473,7 +3759,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiCategorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3512,7 +3798,7 @@ def __eq__(self, other): and mask_equal ) - def clone(self) -> MultiDiscreteTensorSpec: + def clone(self) -> MultiCategorical: return self.__class__( nvec=self.nvec.clone(), shape=None, @@ -3574,9 +3860,7 @@ def _split_self(self): for n, _mask in zip(nvec, mask): shape = self.shape[:-1] result.append( - DiscreteTensorSpec( - n=n, shape=shape, device=device, dtype=dtype, mask=_mask - ) + Categorical(n=n, shape=shape, device=device, dtype=dtype, mask=_mask) ) return result @@ -3629,7 +3913,7 @@ def is_in(self, val: torch.Tensor) -> bool: def to_one_hot( self, val: torch.Tensor, safe: bool = None - ) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]: + ) -> Union[MultiOneHot, torch.Tensor]: """Encodes a discrete tensor from the spec domain into its one-hot correspondent. Args: @@ -3653,16 +3937,24 @@ def to_one_hot( -1, ).to(self.device) - def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec: + def to_one_hot_spec(self) -> MultiOneHot: """Converts the spec to the equivalent one-hot spec.""" nvec = [_space.n for _space in self.space] - return MultiOneHotDiscreteTensorSpec( + return MultiOneHot( nvec, device=self.device, shape=[*self.shape[:-1], sum(nvec)], mask=self.mask, ) + def to_categorical(self, val: torch.Tensor, safe: bool = None) -> MultiCategorical: + """Not op for MultiCategorical.""" + return val + + def to_categorical_spec(self) -> MultiCategorical: + """Not op for MultiCategorical.""" + return self + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -3779,12 +4071,16 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) -class CompositeSpec(TensorSpec): +class Composite(TensorSpec): """A composition of TensorSpecs. + If a ``TensorSpec`` is the set-description of Tensor category, the ``Composite`` class is akin to + the :class:`~tensordict.TensorDict` class. Like :class:`~tensordict.TensorDict`, it has a ``shape`` (akin to the + ``TensorDict``'s ``batch_size``) and an optional ``device``. + Args: *args: if an unnamed argument is passed, it must be a dictionary with keys - matching the expected keys to be found in the :obj:`CompositeSpec` object. + matching the expected keys to be found in the :obj:`Composite` object. This is useful to build nested CompositeSpecs with tuple indices. **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs to be stored. Values can be None, in which case is_in will be assumed @@ -3801,53 +4097,59 @@ class CompositeSpec(TensorSpec): to the batch-size of the corresponding tensordicts. Examples: - >>> pixels_spec = BoundedTensorSpec( - ... torch.zeros(3,32,32), - ... torch.ones(3, 32, 32)) - >>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33), - ... torch.ones(33)) - >>> composite_spec = CompositeSpec( + >>> pixels_spec = Bounded( + ... low=torch.zeros(4, 3, 32, 32), + ... high=torch.ones(4, 3, 32, 32), + ... dtype=torch.uint8 + ... ) + >>> observation_vector_spec = Bounded( + ... low=torch.zeros(4, 33), + ... high=torch.ones(4, 33), + ... dtype=torch.float) + >>> composite_spec = Composite( ... pixels=pixels_spec, - ... observation_vector=observation_vector_spec) - >>> td = TensorDict({"pixels": torch.rand(10,3,32,32), - ... "observation_vector": torch.rand(10,33)}, batch_size=[10]) - >>> print("td (rand) is within bounds: ", composite_spec.is_in(td)) - td (rand) is within bounds: True - >>> td = TensorDict({"pixels": torch.randn(10,3,32,32), - ... "observation_vector": torch.randn(10,33)}, batch_size=[10]) - >>> print("td (randn) is within bounds: ", composite_spec.is_in(td)) - td (randn) is within bounds: False - >>> td_project = composite_spec.project(td) - >>> print("td modification done in place: ", td_project is td) - td modification done in place: True - >>> print("check td is within bounds after projection: ", - ... composite_spec.is_in(td_project)) - check td is within bounds after projection: True - >>> print("random td: ", composite_spec.rand([3,])) - random td: TensorDict( + ... observation_vector=observation_vector_spec, + ... shape=(4,) + ... ) + >>> composite_spec + Composite( + pixels: BoundedDiscrete( + shape=torch.Size([4, 3, 32, 32]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True), + high=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True)), + device=cpu, + dtype=torch.uint8, + domain=discrete), + observation_vector: BoundedContinuous( + shape=torch.Size([4, 33]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + device=None, + shape=torch.Size([4])) + >>> td = composite_spec.rand() + >>> td + TensorDict( fields={ - observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32), - pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)}, - batch_size=torch.Size([3]), + observation_vector: Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, is_shared=False), + pixels: Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([4]), device=None, is_shared=False) - - Examples: >>> # we can build a nested composite spec using unnamed arguments - >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None})) - CompositeSpec( - a: CompositeSpec( + >>> print(Composite({("a", "b"): None, ("a", "c"): None})) + Composite( + a: Composite( b: None, - c: None)) - - CompositeSpec supports nested indexing: - >>> spec = CompositeSpec(obs=None) - >>> spec["nested", "x"] = None - >>> print(spec) - CompositeSpec( - nested: CompositeSpec( - x: None), - x: None) + c: None, + device=None, + shape=torch.Size([])), + device=None, + shape=torch.Size([])) """ @@ -3871,15 +4173,15 @@ def shape(self, value: torch.Size): if self.locked: raise RuntimeError("Cannot modify shape of locked composite spec.") for key, spec in self.items(): - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): if spec.shape[: len(value)] != value: spec.shape = value elif spec is not None: if spec.shape[: len(value)] != value: raise ValueError( - f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the " + f"The shape of the spec and the Composite mismatch during shape resetting: the " f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." + f"Composite.shape={self.shape}." ) self._shape = torch.Size(value) @@ -3896,24 +4198,29 @@ def ndimension(self): def set(self, name, spec): if self.locked: - raise RuntimeError("Cannot modify a locked CompositeSpec.") + raise RuntimeError("Cannot modify a locked Composite.") if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: raise ValueError( - "The shape of the spec and the CompositeSpec mismatch: the first " + "The shape of the spec and the Composite mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." + f"Composite.shape={self.shape}." ) self._specs[name] = spec - def __init__(self, *args, shape=None, device=None, **kwargs): + def __init__( + self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs + ): + # For compatibility with TensorDict + batch_size = kwargs.pop("batch_size", None) + if batch_size is not None: + if shape is not None: + raise TypeError("Cannot specify both batch_size and shape.") + shape = batch_size + if shape is None: - # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it - # optional for composite (for clarity and easiness of use). - # warnings.warn("shape=None for CompositeSpec will soon be deprecated. Make sure you set the " - # "batch size of your CompositeSpec as you would do for a tensordict.") - shape = [] + shape = torch.Size(()) self._shape = torch.Size(shape) self._specs = {} for key, value in kwargs.items(): @@ -3927,7 +4234,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs): if item is None: continue if ( - isinstance(item, CompositeSpec) + isinstance(item, Composite) and item.device is None and _device is not None ): @@ -3936,22 +4243,22 @@ def __init__(self, *args, shape=None, device=None, **kwargs): raise RuntimeError( f"Setting a new attribute ({key}) on another device " f"({item.device} against {_device}). All devices of " - "CompositeSpec must match." + "Composite must match." ) self._device = _device if len(args): if len(args) > 1: raise RuntimeError( - "Got multiple arguments, when at most one is expected for CompositeSpec." + "Got multiple arguments, when at most one is expected for Composite." ) argdict = args[0] - if not isinstance(argdict, (dict, CompositeSpec)): + if not isinstance(argdict, (dict, Composite)): raise RuntimeError( f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." ) for k, item in argdict.items(): if isinstance(item, dict): - item = CompositeSpec(item, shape=shape, device=_device) + item = Composite(item, shape=shape, device=_device) self[k] = item @property @@ -3968,14 +4275,14 @@ def device(self, device: DEVICE_TYPING): self.to(device) def clear_device_(self): - """Clears the device of the CompositeSpec.""" + """Clears the device of the Composite.""" self._device = None for spec in self._specs.values(): spec.clear_device_() return self def __getitem__(self, idx): - """Indexes the current CompositeSpec based on the provided index.""" + """Indexes the current Composite based on the provided index.""" if isinstance(idx, (str, tuple)): idx_unravel = unravel_key(idx) else: @@ -3984,7 +4291,7 @@ def __getitem__(self, idx): if isinstance(idx_unravel, tuple): return self[idx[0]][idx[1:]] if idx_unravel in {"shape", "device", "dtype", "space"}: - raise AttributeError(f"CompositeSpec has no key {idx_unravel}") + raise AttributeError(f"Composite has no key {idx_unravel}") return self._specs[idx_unravel] indexed_shape = _shape_indexing(self.shape, idx) @@ -3996,9 +4303,9 @@ def __getitem__(self, idx): if any( isinstance(v, spec_class) for spec_class in [ - BinaryDiscreteTensorSpec, - MultiDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + MultiCategorical, + OneHot, ] ): protected_dims = 1 @@ -4020,7 +4327,7 @@ def __getitem__(self, idx): ) def get(self, item, default=NO_DEFAULT): - """Gets an item from the CompositeSpec. + """Gets an item from the Composite. If the item is absent, a default value can be passed. @@ -4035,7 +4342,7 @@ def get(self, item, default=NO_DEFAULT): def __setitem__(self, key, value): if isinstance(key, tuple) and len(key) > 1: if key[0] not in self.keys(True): - self[key[0]] = CompositeSpec(shape=self.shape, device=self.device) + self[key[0]] = Composite(shape=self.shape, device=self.device) self[key[0]][key[1:]] = value return elif isinstance(key, tuple): @@ -4044,20 +4351,20 @@ def __setitem__(self, key, value): elif not isinstance(key, str): raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: - raise AttributeError(f"CompositeSpec[{key}] cannot be set") + raise AttributeError(f"Composite[{key}] cannot be set") if isinstance(value, dict): - value = CompositeSpec(value, device=self._device, shape=self.shape) + value = Composite(value, device=self._device, shape=self.shape) if ( value is not None and self.device is not None and value.device != self.device ): - if isinstance(value, CompositeSpec) and value.device is None: + if isinstance(value, Composite) and value.device is None: value = value.clone().to(self.device) else: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " - f"All devices of CompositeSpec must match." + f"All devices of Composite must match." ) self.set(key, value) @@ -4090,13 +4397,13 @@ def encode( for key, item in vals.items(): if item is None: raise RuntimeError( - "CompositeSpec.encode cannot be used with missing values." + "Composite.encode cannot be used with missing values." ) try: out[key] = self[key].encode(item, ignore_device=ignore_device) except KeyError: raise KeyError( - f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key." + f"The Composite instance with keys {self.keys()} does not have a '{key}' key." ) except RuntimeError as err: raise RuntimeError( @@ -4109,7 +4416,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})" + return f"Composite(\n{sub_str},\n device={self._device},\n shape={self.shape})" def type_check( self, @@ -4128,7 +4435,7 @@ def type_check( def is_in(self, val: Union[dict, TensorDictBase]) -> bool: for key, item in self._specs.items(): - if item is None or (isinstance(item, CompositeSpec) and item.is_empty()): + if item is None or (isinstance(item, Composite) and item.is_empty()): continue val_item = val.get(key, NO_DEFAULT) if not item.is_in(val_item): @@ -4144,7 +4451,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase: val.set(key, self._specs[key].project(_val)) return val - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: shape = torch.Size([]) _dict = {} @@ -4166,24 +4473,24 @@ def keys( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecKeysView: # noqa: D417 - """Keys of the CompositeSpec. + """Keys of the Composite. The keys argument reflect those of :class:`tensordict.TensorDict`. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ @@ -4201,22 +4508,22 @@ def items( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecItemsView: # noqa: D417 - """Items of the CompositeSpec. + """Items of the Composite. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ return _CompositeSpecItemsView( @@ -4233,22 +4540,22 @@ def values( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecValuesView: # noqa: D417 - """Values of the CompositeSpec. + """Values of the Composite. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ return _CompositeSpecItemsView( @@ -4263,7 +4570,7 @@ def _reshape(self, shape): key: val.reshape((*shape, *val.shape[self.ndimension() :])) for key, val in self._specs.items() } - return CompositeSpec(_specs, shape=shape) + return Composite(_specs, shape=shape) def _unflatten(self, dim, sizes): shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape @@ -4272,12 +4579,12 @@ def _unflatten(self, dim, sizes): def __len__(self): return len(self.keys()) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite: if dest is None: return self if not isinstance(dest, (str, int, torch.device)): raise ValueError( - "Only device casting is allowed with specs of type CompositeSpec." + "Only device casting is allowed with specs of type Composite." ) if self._device and self._device == torch.device(dest): return self @@ -4292,7 +4599,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: kwargs[key] = value.to(dest) return self.__class__(**kwargs, device=_device, shape=self.shape) - def clone(self) -> CompositeSpec: + def clone(self) -> Composite: try: device = self.device except RuntimeError: @@ -4321,7 +4628,7 @@ def empty(self): def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: return {key: self[key].to_numpy(val) for key, val in val.items()} - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: shape = torch.Size([]) try: @@ -4347,9 +4654,9 @@ def __eq__(self, other): and all((self._specs[key] == spec) for (key, spec) in other._specs.items()) ) - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + def update(self, dict_or_spec: Union[Composite, Dict[str, TensorSpec]]) -> None: for key, item in dict_or_spec.items(): - if key in self.keys(True) and isinstance(self[key], CompositeSpec): + if key in self.keys(True) and isinstance(self[key], Composite): self[key].update(item) continue try: @@ -4390,7 +4697,7 @@ def expand(self, *shape): else None for key, value in tuple(self.items()) } - out = CompositeSpec( + out = Composite( specs, shape=shape, device=device, @@ -4411,7 +4718,7 @@ def squeeze(self, dim: int | None = None): except RuntimeError: device = self._device - return CompositeSpec( + return Composite( {key: value.squeeze(dim) for key, value in self.items()}, shape=shape, device=device, @@ -4437,7 +4744,7 @@ def unsqueeze(self, dim: int): except RuntimeError: device = self._device - return CompositeSpec( + return Composite( { key: value.unsqueeze(dim) if value is not None else None for key, value in self.items() @@ -4466,19 +4773,19 @@ def unbind(self, dim: int = 0): ) def lock_(self, recurse=False): - """Locks the CompositeSpec and prevents modification of its content. + """Locks the Composite and prevents modification of its content. This is only a first-level lock, unless specified otherwise through the ``recurse`` arg. Leaf specs can always be modified in place, but they cannot be replaced - in their CompositeSpec parent. + in their Composite parent. Examples: >>> shape = [3, 4, 5] - >>> spec = CompositeSpec( - ... a=CompositeSpec( - ... b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2] + >>> spec = Composite( + ... a=Composite( + ... b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2] ... ), ... shape=shape[:1], ... ) @@ -4509,12 +4816,12 @@ def lock_(self, recurse=False): self._locked = True if recurse: for value in self.values(): - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): value.lock_(recurse) return self def unlock_(self, recurse=False): - """Unlocks the CompositeSpec and allows modification of its content. + """Unlocks the Composite and allows modification of its content. This is only a first-level lock modification, unless specified otherwise through the ``recurse`` arg. @@ -4523,7 +4830,7 @@ def unlock_(self, recurse=False): self._locked = False if recurse: for value in self.values(): - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): value.unlock_(recurse) return self @@ -4532,7 +4839,7 @@ def locked(self): return self._locked -class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): +class StackedComposite(_LazyStackedMixin[Composite], Composite): """A lazy representation of a stack of composite specs. Stacks composite specs together along one dimension. @@ -4548,7 +4855,7 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): def update(self, dict) -> None: for key, item in dict.items(): if key in self.keys() and isinstance( - item, (Dict, CompositeSpec, LazyStackedCompositeSpec) + item, (Dict, Composite, StackedComposite) ): for spec, sub_item in zip(self._specs, item.unbind(self.dim)): spec[key].update(sub_item) @@ -4557,7 +4864,7 @@ def update(self, dict) -> None: return self def __eq__(self, other): - if not isinstance(other, LazyStackedCompositeSpec): + if not isinstance(other, StackedComposite): return False if len(self._specs) != len(other._specs): return False @@ -4576,7 +4883,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: if safe: if val.shape[self.dim] != len(self._specs): raise ValueError( - "Size of LazyStackedCompositeSpec and val differ along the " + "Size of StackedComposite and val differ along the " "stacking dimension" ) for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): @@ -4674,7 +4981,7 @@ def __repr__(self) -> str: string = ",\n".join( [sub_str, exclusive_key_str, device_str, shape_str, stack_dim] ) - return f"LazyStackedCompositeSpec(\n{string})" + return f"StackedComposite(\n{string})" def repr_exclusive_keys(self): keys = set(self.keys()) @@ -4812,7 +5119,7 @@ def expand(self, *shape): ) def empty(self): - return LazyStackedCompositeSpec.maybe_dense_stack( + return StackedComposite.maybe_dense_stack( [spec.empty() for spec in self._specs], dim=self.stack_dim ) @@ -4821,7 +5128,7 @@ def encode( ) -> Dict[str, torch.Tensor]: raise NOT_IMPLEMENTED_ERROR - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4830,7 +5137,7 @@ def zero(self, shape=None) -> TensorDictBase: [spec.zero(shape) for spec in self._specs], dim ) - def one(self, shape=None) -> TensorDictBase: + def one(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4839,7 +5146,7 @@ def one(self, shape=None) -> TensorDictBase: [spec.one(shape) for spec in self._specs], dim ) - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4849,7 +5156,6 @@ def rand(self, shape=None) -> TensorDictBase: ) -# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: @TensorSpec.implements_for_spec(torch.stack) def _stack_specs(list_of_spec, dim, out=None): if out is not None: @@ -4882,12 +5188,12 @@ def _stack_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) - return LazyStackedTensorSpec(*list_of_spec, dim=dim) + return Stacked(*list_of_spec, dim=dim) else: raise NotImplementedError -@CompositeSpec.implements_for_spec(torch.stack) +@Composite.implements_for_spec(torch.stack) def _stack_composite_specs(list_of_spec, dim, out=None): if out is not None: raise NotImplementedError( @@ -4897,7 +5203,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") spec0 = list_of_spec[0] - if isinstance(spec0, CompositeSpec): + if isinstance(spec0, Composite): devices = {spec.device for spec in list_of_spec} if len(devices) == 1: device = list(devices)[0] @@ -4912,7 +5218,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): all_equal = True for spec in list_of_spec[1:]: - if not isinstance(spec, CompositeSpec): + if not isinstance(spec, Composite): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of spec in " "the list." @@ -4929,7 +5235,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) - return LazyStackedCompositeSpec(*list_of_spec, dim=dim) + return StackedComposite(*list_of_spec, dim=dim) else: raise NotImplementedError @@ -4939,8 +5245,8 @@ def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: return spec.squeeze(*args, **kwargs) -@CompositeSpec.implements_for_spec(torch.squeeze) -def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: +@Composite.implements_for_spec(torch.squeeze) +def _squeeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite: return spec.squeeze(*args, **kwargs) @@ -4949,16 +5255,16 @@ def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: return spec.unsqueeze(*args, **kwargs) -@CompositeSpec.implements_for_spec(torch.unsqueeze) -def _unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: +@Composite.implements_for_spec(torch.unsqueeze) +def _unsqueeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite: return spec.unsqueeze(*args, **kwargs) def _keys_to_empty_composite_spec(keys): - """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.""" + """Given a list of keys, creates a Composite tree where each leaf is assigned a None value.""" if not len(keys): return - c = CompositeSpec() + c = Composite() for key in keys: if isinstance(key, str): c[key] = None @@ -4966,7 +5272,7 @@ def _keys_to_empty_composite_spec(keys): if c[key[0]] is None: # if the value is None we just replace it c[key[0]] = _keys_to_empty_composite_spec([key[1:]]) - elif isinstance(c[key[0]], CompositeSpec): + elif isinstance(c[key[0]], Composite): # if the value is Composite, we update it out = _keys_to_empty_composite_spec([key[1:]]) if out is not None: @@ -5010,11 +5316,11 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: class _CompositeSpecItemsView: - """Wrapper class that enables richer behaviour of `items` for CompositeSpec.""" + """Wrapper class that enables richer behaviour of `items` for Composite.""" def __init__( self, - composite: CompositeSpec, + composite: Composite, include_nested, leaves_only, *, @@ -5032,13 +5338,13 @@ def __iter__(self): if is_leaf in (None, _NESTED_TENSORS_AS_LISTS): def _is_leaf(cls): - return not issubclass(cls, CompositeSpec) + return not issubclass(cls, Composite) else: _is_leaf = is_leaf def _iter_from_item(key, item): - if self.include_nested and isinstance(item, CompositeSpec): + if self.include_nested and isinstance(item, Composite): for subkey, subitem in item.items( include_nested=True, leaves_only=self.leaves_only, @@ -5063,7 +5369,7 @@ def _iter_from_item(key, item): def _get_composite_items(self, is_leaf): - if isinstance(self.composite, LazyStackedCompositeSpec): + if isinstance(self.composite, StackedComposite): from tensordict.base import _NESTED_TENSORS_AS_LISTS if is_leaf is _NESTED_TENSORS_AS_LISTS: @@ -5150,5 +5456,149 @@ def _minmax_dtype(dtype): def _remove_neg_shapes(*shape): if len(shape) == 1 and not isinstance(shape[0], int): - return _remove_neg_shapes(*shape[0]) + shape = shape[0] + if isinstance(shape, np.integer): + shape = (int(shape),) + return _remove_neg_shapes(*shape) return torch.Size([int(d) if d >= 0 else 1 for d in shape]) + + +############## +# Legacy +# +class _LegacySpecMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + warnings.warn( + f"The {cls.__name__} has been deprecated and will be removed in v0.7. Please use " + f"{cls.__bases__[-1].__name__} instead.", + category=DeprecationWarning, + ) + instance = super().__call__(*args, **kwargs) + if ( + type(instance) in (UnboundedDiscreteTensorSpec, UnboundedDiscrete) + and instance.domain == "continuous" + ): + instance.__class__ = UnboundedContinuous + elif ( + type(instance) in (UnboundedContinuousTensorSpec, UnboundedContinuous) + and instance.domain == "discrete" + ): + instance.__class__ = UnboundedDiscrete + return instance + + def __instancecheck__(cls, instance): + check0 = super().__instancecheck__(instance) + if check0: + return True + parent_cls = cls.__bases__[-1] + return isinstance(instance, parent_cls) + + +class CompositeSpec(Composite, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Composite`.""" + + ... + + +class OneHotDiscreteTensorSpec(OneHot, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.OneHot`.""" + + ... + + +class MultiOneHotDiscreteTensorSpec(MultiOneHot, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.MultiOneHot`.""" + + ... + + +class NonTensorSpec(NonTensor, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.NonTensor`.""" + + ... + + +class MultiDiscreteTensorSpec(MultiCategorical, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.MultiCategorical`.""" + + ... + + +class LazyStackedTensorSpec(Stacked, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Stacked`.""" + + ... + + +class LazyStackedCompositeSpec(StackedComposite, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.StackedComposite`.""" + + ... + + +class DiscreteTensorSpec(Categorical, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Categorical`.""" + + ... + + +class BinaryDiscreteTensorSpec(Binary, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Binary`.""" + + ... + + +_BoundedLegacyMeta = type("_BoundedLegacyMeta", (_LegacySpecMeta, _BoundedMeta), {}) + + +class BoundedTensorSpec(Bounded, metaclass=_BoundedLegacyMeta): + """Deprecated version of :class:`torchrl.data.Bounded`.""" + + ... + + +class _UnboundedContinuousMetaclass(_UnboundedMeta): + def __instancecheck__(cls, instance): + return isinstance(instance, Unbounded) and instance.domain == "continuous" + + +_LegacyUnboundedContinuousMetaclass = type( + "_LegacyUnboundedDiscreteMetaclass", + (_UnboundedContinuousMetaclass, _LegacySpecMeta), + {}, +) + + +class UnboundedContinuousTensorSpec( + Unbounded, metaclass=_LegacyUnboundedContinuousMetaclass +): + """Deprecated version of :class:`torchrl.data.Unbounded` with continuous space.""" + + ... + + +class _UnboundedDiscreteMetaclass(_UnboundedMeta): + def __instancecheck__(cls, instance): + return isinstance(instance, Unbounded) and instance.domain == "discrete" + + +_LegacyUnboundedDiscreteMetaclass = type( + "_LegacyUnboundedDiscreteMetaclass", + (_UnboundedDiscreteMetaclass, _LegacySpecMeta), + {}, +) + + +class UnboundedDiscreteTensorSpec( + Unbounded, metaclass=_LegacyUnboundedDiscreteMetaclass +): + """Deprecated version of :class:`torchrl.data.Unbounded` with discrete space.""" + + def __init__( + self, + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.int64, + **kwargs, + ): + super().__init__(shape=shape, device=device, dtype=dtype, **kwargs) diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index fb4ec30daed..214c79b4686 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -13,14 +13,14 @@ from torch import Tensor from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - LazyStackedCompositeSpec, - LazyStackedTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + OneHot, + Stacked, + StackedComposite, TensorSpec, ) @@ -50,10 +50,10 @@ ACTION_SPACE_MAP = { - OneHotDiscreteTensorSpec: "one_hot", - MultiOneHotDiscreteTensorSpec: "mult_one_hot", - BinaryDiscreteTensorSpec: "binary", - DiscreteTensorSpec: "categorical", + OneHot: "one_hot", + MultiOneHot: "mult_one_hot", + Binary: "binary", + Categorical: "categorical", "one_hot": "one_hot", "one-hot": "one_hot", "mult_one_hot": "mult_one_hot", @@ -62,7 +62,7 @@ "multi-one-hot": "mult_one_hot", "binary": "binary", "categorical": "categorical", - MultiDiscreteTensorSpec: "multi_categorical", + MultiCategorical: "multi_categorical", "multi_categorical": "multi_categorical", "multi-categorical": "multi_categorical", "multi_discrete": "multi_categorical", @@ -71,14 +71,14 @@ def consolidate_spec( - spec: CompositeSpec, + spec: Composite, recurse_through_entries: bool = True, recurse_through_stack: bool = True, ): """Given a TensorSpec, removes exclusive keys by adding 0 shaped specs. Args: - spec (CompositeSpec): the spec to be consolidated. + spec (Composite): the spec to be consolidated. recurse_through_entries (bool): if True, call the function recursively on all entries of the spec. Default is True. recurse_through_stack (bool): if True, if the provided spec is lazy, the function recursively @@ -87,10 +87,10 @@ def consolidate_spec( """ spec = spec.clone() - if not isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + if not isinstance(spec, (Composite, StackedComposite)): return spec - if isinstance(spec, LazyStackedCompositeSpec): + if isinstance(spec, StackedComposite): keys = set(spec.keys()) # shared keys exclusive_keys_per_spec = [ set() for _ in range(len(spec._specs)) @@ -128,7 +128,7 @@ def consolidate_spec( if recurse_through_entries: for key, value in spec.items(): - if isinstance(value, (CompositeSpec, LazyStackedCompositeSpec)): + if isinstance(value, (Composite, StackedComposite)): spec.set( key, consolidate_spec( @@ -145,16 +145,16 @@ def _empty_like_spec(specs: List[TensorSpec], shape): "Found same key in lazy specs corresponding to entries with different classes" ) spec = specs[0] - if isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + if isinstance(spec, (Composite, StackedComposite)): # the exclusive key has values which are CompositeSpecs -> # we create an empty composite spec with same batch size return spec.empty() - elif isinstance(spec, LazyStackedTensorSpec): + elif isinstance(spec, Stacked): # the exclusive key has values which are LazyStackedTensorSpecs -> # we create a LazyStackedTensorSpec with the same shape (aka same -1s) as the first in the list. # this will not add any new -1s when they are stacked shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :]) - return LazyStackedTensorSpec( + return Stacked( *[_empty_like_spec(spec._specs, shape) for _ in spec._specs], dim=spec.stack_dim, ) @@ -191,14 +191,14 @@ def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True): spec (TensorSpec): the spec to check recurse (bool): if True, check recursively in nested specs. Default is True. """ - if isinstance(spec, LazyStackedCompositeSpec): + if isinstance(spec, StackedComposite): keys = set(spec.keys()) for inner_td in spec._specs: if recurse and not check_no_exclusive_keys(inner_td): return False if set(inner_td.keys()) != keys: return False - elif isinstance(spec, CompositeSpec) and recurse: + elif isinstance(spec, Composite) and recurse: for value in spec.values(): if not check_no_exclusive_keys(value): return False @@ -214,9 +214,9 @@ def contains_lazy_spec(spec: TensorSpec) -> bool: spec (TensorSpec): the spec to check """ - if isinstance(spec, (LazyStackedTensorSpec, LazyStackedCompositeSpec)): + if isinstance(spec, (Stacked, StackedComposite)): return True - elif isinstance(spec, CompositeSpec): + elif isinstance(spec, Composite): for inner_spec in spec.values(): if contains_lazy_spec(inner_spec): return True @@ -253,7 +253,7 @@ def __call__(self, *args, **kwargs) -> Any: def _process_action_space_spec(action_space, spec): original_spec = spec composite_spec = False - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): # this will break whenever our action is more complex than a single tensor try: if "action" in spec.keys(): @@ -274,8 +274,8 @@ def _process_action_space_spec(action_space, spec): "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." ) if action_space is not None: - if isinstance(action_space, CompositeSpec): - raise ValueError("action_space cannot be of type CompositeSpec.") + if isinstance(action_space, Composite): + raise ValueError("action_space cannot be of type Composite.") if ( spec is not None and isinstance(action_space, TensorSpec) @@ -305,7 +305,7 @@ def _process_action_space_spec(action_space, spec): def _find_action_space(action_space): if isinstance(action_space, TensorSpec): - if isinstance(action_space, CompositeSpec): + if isinstance(action_space, Composite): if "action" in action_space.keys(): _key = "action" else: diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 4996e527527..f915af52bcc 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -36,7 +36,7 @@ logger as torchrl_logger, VERBOSE, ) -from torchrl.data.tensor_specs import CompositeSpec, NonTensorSpec +from torchrl.data.tensor_specs import Composite, NonTensor from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData from torchrl.envs.env_creator import get_env_metadata @@ -550,7 +550,7 @@ def _set_properties(self): cls = type(self) - def _check_for_empty_spec(specs: CompositeSpec): + def _check_for_empty_spec(specs: Composite): for subspec in ( "full_state_spec", "full_action_spec", @@ -559,9 +559,9 @@ def _check_for_empty_spec(specs: CompositeSpec): "full_observation_spec", ): for key, spec in reversed( - list(specs.get(subspec, default=CompositeSpec()).items(True)) + list(specs.get(subspec, default=Composite()).items(True)) ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): raise RuntimeError( f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using " f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs." @@ -675,7 +675,7 @@ def _create_td(self) -> None: self.full_done_spec, ): for key, _spec in spec.items(True, True): - if isinstance(_spec, NonTensorSpec): + if isinstance(_spec, NonTensor): non_tensor_keys.append(key) self._non_tensor_keys = non_tensor_keys diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9216b58e86..3277158af57 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -25,12 +25,7 @@ seed_generator, ) -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -62,7 +57,7 @@ def __init__( self, *, tensordict: TensorDictBase, - specs: CompositeSpec, + specs: Composite, batch_size: torch.Size, env_str: str, device: torch.device, @@ -91,7 +86,7 @@ def tensordict(self, value: TensorDictBase): self._tensordict = value.to("cpu") @specs.setter - def specs(self, value: CompositeSpec): + def specs(self, value: Composite): self._specs = value.to("cpu") @staticmethod @@ -212,29 +207,29 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): be done after a call to :meth:`~.reset` is made. Defaults to ``False``. Attributes: - done_spec (CompositeSpec): equivalent to ``full_done_spec`` as all + done_spec (Composite): equivalent to ``full_done_spec`` as all ``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf action if only one action tensor is to be expected. Otherwise links to ``full_action_spec``. - observation_spec (CompositeSpec): equivalent to ``full_observation_spec``. + observation_spec (Composite): equivalent to ``full_observation_spec``. reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf reward if only one reward tensor is to be expected. Otherwise links to ``full_reward_spec``. - state_spec (CompositeSpec): equivalent to ``full_state_spec``. - full_done_spec (CompositeSpec): a composite spec such that ``full_done_spec.zero()`` + state_spec (Composite): equivalent to ``full_state_spec``. + full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()`` returns a tensordict containing only the leaves encoding the done status of the environment. - full_action_spec (CompositeSpec): a composite spec such that ``full_action_spec.zero()`` + full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()`` returns a tensordict containing only the leaves encoding the action of the environment. - full_observation_spec (CompositeSpec): a composite spec such that ``full_observation_spec.zero()`` + full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()`` returns a tensordict containing only the leaves encoding the observation of the environment. - full_reward_spec (CompositeSpec): a composite spec such that ``full_reward_spec.zero()`` + full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()`` returns a tensordict containing only the leaves encoding the reward of the environment. - full_state_spec (CompositeSpec): a composite spec such that ``full_state_spec.zero()`` + full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()`` returns a tensordict containing only the leaves encoding the inputs (actions excluded) of the environment. batch_size (torch.Size): The batch-size of the environment. @@ -253,9 +248,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): >>> from torchrl.envs import EnvBase >>> class CounterEnv(EnvBase): ... def __init__(self, batch_size=(), device=None, **kwargs): - ... self.observation_spec = CompositeSpec( - ... count=UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int64)) - ... self.action_spec = UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int8) + ... self.observation_spec = Composite( + ... count=Unbounded(batch_size, device=device, dtype=torch.int64)) + ... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8) ... # done spec and reward spec are set automatically ... def _step(self, tensordict): ... @@ -264,10 +259,10 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): >>> env.batch_size # how many envs are run at once torch.Size([]) >>> env.input_spec - CompositeSpec( + Composite( full_state_spec: None, - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -276,7 +271,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> env.action_spec - BoundedTensorSpec( + BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -285,8 +280,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous) >>> env.observation_spec - CompositeSpec( - observation: BoundedTensorSpec( + Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -295,14 +290,14 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -310,16 +305,16 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): domain=discrete) >>> # the output_spec contains all the expected outputs >>> env.output_spec - CompositeSpec( - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_observation_spec: CompositeSpec( - observation: BoundedTensorSpec( + full_observation_spec: Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -327,8 +322,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -544,10 +539,10 @@ def input_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.input_spec - CompositeSpec( + Composite( full_state_spec: None, - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -560,7 +555,7 @@ def input_spec(self) -> TensorSpec: """ input_spec = self.__dict__.get("_input_spec") if input_spec is None: - input_spec = CompositeSpec( + input_spec = Composite( full_state_spec=None, shape=self.batch_size, device=self.device, @@ -591,16 +586,16 @@ def output_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.output_spec - CompositeSpec( - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_observation_spec: CompositeSpec( - observation: BoundedTensorSpec( + full_observation_spec: Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -608,8 +603,8 @@ def output_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -620,7 +615,7 @@ def output_spec(self) -> TensorSpec: """ output_spec = self.__dict__.get("_output_spec") if output_spec is None: - output_spec = CompositeSpec( + output_spec = Composite( shape=self.batch_size, device=self.device, ).lock_() @@ -688,9 +683,9 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a simple spec, this will be returned. - >>> env.action_spec = UnboundedContinuousTensorSpec(1) + >>> env.action_spec = Unbounded(1) >>> env.action_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -702,9 +697,9 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1)}}) + >>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}}) >>> env.action_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -716,11 +711,11 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. - >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1), "another_action": DiscreteTensorSpec(1)}}) + >>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}}) >>> env.action_spec - CompositeSpec( - nested: CompositeSpec( - action: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + action: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -728,7 +723,7 @@ def action_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), - another_action: DiscreteTensorSpec( + another_action: Categorical( shape=torch.Size([]), space=DiscreteBox(n=1), device=cpu, @@ -745,7 +740,7 @@ def action_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.action_spec - BoundedTensorSpec( + BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -794,16 +789,16 @@ def action_spec(self, value: TensorSpec) -> None: f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the action spec. " + "An empty Composite was passed for the action spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( action=value.to(device), shape=self.batch_size, device=device ) @@ -812,10 +807,10 @@ def action_spec(self, value: TensorSpec) -> None: self.input_spec.lock_() @property - def full_action_spec(self) -> CompositeSpec: + def full_action_spec(self) -> Composite: """The full action spec. - ``full_action_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_action_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the action entries. Examples: @@ -824,8 +819,8 @@ def full_action_spec(self) -> CompositeSpec: ... break >>> env = BraxEnv(envname) >>> env.full_action_spec - CompositeSpec( - action: BoundedTensorSpec( + Composite( + action: BoundedContinuous( shape=torch.Size([8]), space=ContinuousBox( low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True), @@ -838,7 +833,7 @@ def full_action_spec(self) -> CompositeSpec: return self.input_spec["full_action_spec"] @full_action_spec.setter - def full_action_spec(self, spec: CompositeSpec) -> None: + def full_action_spec(self, spec: Composite) -> None: self.action_spec = spec # Reward spec @@ -881,9 +876,9 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a simple spec, this will be returned. - >>> env.reward_spec = UnboundedContinuousTensorSpec(1) + >>> env.reward_spec = Unbounded(1) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -895,9 +890,9 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1)}}) + >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}}) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -909,11 +904,11 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. - >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1), "another_reward": DiscreteTensorSpec(1)}}) + >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}}) >>> env.reward_spec - CompositeSpec( - nested: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -921,7 +916,7 @@ def reward_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), - another_reward: DiscreteTensorSpec( + another_reward: Categorical( shape=torch.Size([]), space=DiscreteBox(n=1), device=cpu, @@ -938,7 +933,7 @@ def reward_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, @@ -952,7 +947,7 @@ def reward_spec(self) -> TensorSpec: # this will be raised if there is not full_reward_spec (unlikely) or no reward_key # Since output_spec is lazily populated with an empty composite spec for # reward_spec, the second case is much more likely to occur. - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=(*self.batch_size, 1), device=self.device, ) @@ -982,16 +977,16 @@ def reward_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the reward spec. " + "An empty Composite was passed for the reward spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( reward=value.to(device), shape=self.batch_size, device=device ) for leaf in value.values(True, True): @@ -1007,10 +1002,10 @@ def reward_spec(self, value: TensorSpec) -> None: self.output_spec.lock_() @property - def full_reward_spec(self) -> CompositeSpec: + def full_reward_spec(self) -> Composite: """The full reward spec. - ``full_reward_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_reward_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the reward entries. Examples: @@ -1019,9 +1014,9 @@ def full_reward_spec(self) -> CompositeSpec: >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1")) >>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward"))) >>> env.full_reward_spec - CompositeSpec( - nested: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -1034,7 +1029,7 @@ def full_reward_spec(self) -> CompositeSpec: return self.output_spec["full_reward_spec"] @full_reward_spec.setter - def full_reward_spec(self, spec: CompositeSpec) -> None: + def full_reward_spec(self, spec: Composite) -> None: self.reward_spec = spec.to(self.device) if self.device is not None else spec # done spec @@ -1068,10 +1063,10 @@ def done_key(self): return self.done_keys[0] @property - def full_done_spec(self) -> CompositeSpec: + def full_done_spec(self) -> Composite: """The full done spec. - ``full_done_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_done_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the done entries. It can be used to generate fake data with a structure that mimics the one obtained at runtime. @@ -1081,14 +1076,14 @@ def full_done_spec(self) -> CompositeSpec: >>> from torchrl.envs import GymWrapper >>> env = GymWrapper(gymnasium.make("Pendulum-v1")) >>> env.full_done_spec - CompositeSpec( - done: DiscreteTensorSpec( + Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, dtype=torch.bool, domain=discrete), - truncated: DiscreteTensorSpec( + truncated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -1099,7 +1094,7 @@ def full_done_spec(self) -> CompositeSpec: return self.output_spec["full_done_spec"] @full_done_spec.setter - def full_done_spec(self, spec: CompositeSpec) -> None: + def full_done_spec(self, spec: Composite) -> None: self.done_spec = spec.to(self.device) if self.device is not None else spec # Done spec: done specs belong to output_spec @@ -1111,9 +1106,9 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a simple spec, this will be returned. - >>> env.done_spec = DiscreteTensorSpec(2, dtype=torch.bool) + >>> env.done_spec = Categorical(2, dtype=torch.bool) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1123,9 +1118,9 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool)}}) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1135,17 +1130,17 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. - >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool), "another_done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool), "another_done": Categorical(2, dtype=torch.bool)}}) >>> env.done_spec - CompositeSpec( - nested: CompositeSpec( - done: DiscreteTensorSpec( + Composite( + nested: Composite( + done: Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, dtype=torch.bool, domain=discrete), - another_done: DiscreteTensorSpec( + another_done: Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1162,7 +1157,7 @@ def done_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -1185,16 +1180,16 @@ def _create_done_specs(self): try: full_done_spec = self.output_spec["full_done_spec"] except KeyError: - full_done_spec = CompositeSpec( + full_done_spec = Composite( shape=self.output_spec.shape, device=self.output_spec.device ) - full_done_spec["done"] = DiscreteTensorSpec( + full_done_spec["done"] = Categorical( n=2, shape=(*full_done_spec.shape, 1), dtype=torch.bool, device=self.device, ) - full_done_spec["terminated"] = DiscreteTensorSpec( + full_done_spec["terminated"] = Categorical( n=2, shape=(*full_done_spec.shape, 1), dtype=torch.bool, @@ -1215,7 +1210,7 @@ def check_local_done(spec): spec["terminated"] = item.clone() elif key == "terminated" and "done" not in spec.keys(): spec["done"] = item.clone() - elif isinstance(item, CompositeSpec): + elif isinstance(item, Composite): check_local_done(item) else: if shape is None: @@ -1229,10 +1224,10 @@ def check_local_done(spec): # if the spec is empty, we need to add a done and terminated manually if spec.is_empty(): - spec["done"] = DiscreteTensorSpec( + spec["done"] = Categorical( n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) - spec["terminated"] = DiscreteTensorSpec( + spec["terminated"] = Categorical( n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) @@ -1260,16 +1255,16 @@ def done_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the done spec. " + "An empty Composite was passed for the done spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( done=value.to(device), terminated=value.to(device), shape=self.batch_size, @@ -1290,10 +1285,10 @@ def done_spec(self, value: TensorSpec) -> None: # observation spec: observation specs belong to output_spec @property - def observation_spec(self) -> CompositeSpec: + def observation_spec(self) -> Composite: """Observation spec. - Must be a :class:`torchrl.data.CompositeSpec` instance. + Must be a :class:`torchrl.data.Composite` instance. The keys listed in the spec are directly accessible after reset and step. In TorchRL, even though they are not properly speaking "observations" @@ -1307,8 +1302,8 @@ def observation_spec(self) -> CompositeSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.observation_spec - CompositeSpec( - observation: BoundedTensorSpec( + Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -1320,7 +1315,7 @@ def observation_spec(self) -> CompositeSpec: """ observation_spec = self.output_spec["full_observation_spec"] if observation_spec is None: - observation_spec = CompositeSpec(shape=self.batch_size, device=self.device) + observation_spec = Composite(shape=self.batch_size, device=self.device) self.output_spec.unlock_() self.output_spec["full_observation_spec"] = observation_spec self.output_spec.lock_() @@ -1330,7 +1325,7 @@ def observation_spec(self) -> CompositeSpec: def observation_spec(self, value: TensorSpec) -> None: try: self.output_spec.unlock_() - if not isinstance(value, CompositeSpec): + if not isinstance(value, Composite): raise TypeError("The type of an observation_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( @@ -1348,19 +1343,19 @@ def observation_spec(self, value: TensorSpec) -> None: self.output_spec.lock_() @property - def full_observation_spec(self) -> CompositeSpec: + def full_observation_spec(self) -> Composite: return self.observation_spec @full_observation_spec.setter - def full_observation_spec(self, spec: CompositeSpec): + def full_observation_spec(self, spec: Composite): self.observation_spec = spec # state spec: state specs belong to input_spec @property - def state_spec(self) -> CompositeSpec: + def state_spec(self) -> Composite: """State spec. - Must be a :class:`torchrl.data.CompositeSpec` instance. + Must be a :class:`torchrl.data.Composite` instance. The keys listed here should be provided as input alongside actions to the environment. In TorchRL, even though they are not properly speaking "state" @@ -1376,10 +1371,10 @@ def state_spec(self) -> CompositeSpec: ... break >>> env = BraxEnv(envname) >>> env.state_spec - CompositeSpec( - state: CompositeSpec( - pipeline_state: CompositeSpec( - q: UnboundedContinuousTensorSpec( + Composite( + state: Composite( + pipeline_state: Composite( + q: UnboundedContinuous( shape=torch.Size([15]), space=None, device=cpu, @@ -1391,14 +1386,14 @@ def state_spec(self) -> CompositeSpec: """ state_spec = self.input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=self.batch_size, device=self.device) + state_spec = Composite(shape=self.batch_size, device=self.device) self.input_spec.unlock_() self.input_spec["full_state_spec"] = state_spec self.input_spec.lock_() return state_spec @state_spec.setter - def state_spec(self, value: CompositeSpec) -> None: + def state_spec(self, value: Composite) -> None: try: self.input_spec.unlock_() try: @@ -1406,12 +1401,12 @@ def state_spec(self, value: CompositeSpec) -> None: except AttributeError: pass if value is None: - self.input_spec["full_state_spec"] = CompositeSpec( + self.input_spec["full_state_spec"] = Composite( device=self.device, shape=self.batch_size ) else: device = self.input_spec.device - if not isinstance(value, CompositeSpec): + if not isinstance(value, Composite): raise TypeError("The type of an state_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( @@ -1428,10 +1423,10 @@ def state_spec(self, value: CompositeSpec) -> None: self.input_spec.lock_() @property - def full_state_spec(self) -> CompositeSpec: + def full_state_spec(self) -> Composite: """The full state spec. - ``full_state_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_state_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the state entries (ie, the input data that is not action). Examples: @@ -1440,10 +1435,10 @@ def full_state_spec(self) -> CompositeSpec: ... break >>> env = BraxEnv(envname) >>> env.full_state_spec - CompositeSpec( - state: CompositeSpec( - pipeline_state: CompositeSpec( - q: UnboundedContinuousTensorSpec( + Composite( + state: Composite( + pipeline_state: Composite( + q: UnboundedContinuous( shape=torch.Size([15]), space=None, device=cpu, @@ -1455,7 +1450,7 @@ def full_state_spec(self) -> CompositeSpec: return self.state_spec @full_state_spec.setter - def full_state_spec(self, spec: CompositeSpec) -> None: + def full_state_spec(self, spec: Composite) -> None: self.state_spec = spec def step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1494,7 +1489,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: @classmethod def _complete_done( - cls, done_spec: CompositeSpec, data: TensorDictBase + cls, done_spec: Composite, data: TensorDictBase ) -> TensorDictBase: """Completes the data structure at step time to put missing done keys.""" # by default, if a done key is missing, it is assumed that it is False @@ -1508,7 +1503,7 @@ def _complete_done( i = -1 for i, (key, item) in enumerate(done_spec.items()): # noqa: B007 val = data.get(key, None) - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): if val is not None: cls._complete_done(item, val) continue @@ -2300,14 +2295,14 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa return self.step(tensordict) @property - def specs(self) -> CompositeSpec: + def specs(self) -> Composite: """Returns a Composite container where all the environment are present. This feature allows one to create an environment, retrieve all of the specs in a single data container and then erase the environment from the workspace. """ - return CompositeSpec( + return Composite( output_spec=self.output_spec, input_spec=self.input_spec, shape=self.batch_size, @@ -3169,7 +3164,7 @@ def _do_nothing(): return -def _has_dynamic_specs(spec: CompositeSpec): +def _has_dynamic_specs(spec: Composite): from tensordict.base import _NESTED_TENSORS_AS_LISTS return any( diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index 8253e3df9b7..f785d1cedd9 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -6,11 +6,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs.common import EnvBase from torchrl.envs.utils import make_composite_from_td @@ -21,125 +17,193 @@ class PendulumEnv(EnvBase): See the Pendulum tutorial for more details: :ref:`tutorial `. Specs: - CompositeSpec( - output_spec: CompositeSpec( - full_observation_spec: CompositeSpec( - th: BoundedTensorSpec( + >>> env = PendulumEnv() + >>> env.specs + Composite( + output_spec: Composite( + full_observation_spec: Composite( + th: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - thdot: BoundedTensorSpec( + thdot: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - params: CompositeSpec( - max_speed: UnboundedContinuousTensorSpec( + params: Composite( + max_speed: UnboundedDiscrete( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)), + device=cpu, dtype=torch.int64, domain=discrete), - max_torque: UnboundedContinuousTensorSpec( + max_torque: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - dt: UnboundedContinuousTensorSpec( + dt: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - g: UnboundedContinuousTensorSpec( + g: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - m: UnboundedContinuousTensorSpec( + m: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - l: UnboundedContinuousTensorSpec( + l: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( - low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), - space=DiscreteBox(n=2), + space=CategoricalBox(n=2), + device=cpu, dtype=torch.bool, domain=discrete), - terminated: DiscreteTensorSpec( + terminated: Categorical( shape=torch.Size([1]), - space=DiscreteBox(n=2), + space=CategoricalBox(n=2), + device=cpu, dtype=torch.bool, domain=discrete), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - input_spec: CompositeSpec( - full_state_spec: CompositeSpec( - th: BoundedTensorSpec( + input_spec: Composite( + full_state_spec: Composite( + th: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - thdot: BoundedTensorSpec( + thdot: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - params: CompositeSpec( - max_speed: UnboundedContinuousTensorSpec( + params: Composite( + max_speed: UnboundedDiscrete( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)), + device=cpu, dtype=torch.int64, domain=discrete), - max_torque: UnboundedContinuousTensorSpec( + max_torque: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - dt: UnboundedContinuousTensorSpec( + dt: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - g: UnboundedContinuousTensorSpec( + g: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - m: UnboundedContinuousTensorSpec( + m: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - l: UnboundedContinuousTensorSpec( + l: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( - low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])) """ @@ -240,14 +304,14 @@ def _reset(self, tensordict): def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] - self.observation_spec = CompositeSpec( - th=BoundedTensorSpec( + self.observation_spec = Composite( + th=Bounded( low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32, ), - thdot=BoundedTensorSpec( + thdot=Bounded( low=-td_params["params", "max_speed"], high=td_params["params", "max_speed"], shape=(), @@ -265,22 +329,22 @@ def _make_spec(self, td_params): self.state_spec = self.observation_spec.clone() # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` will be called supported - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-td_params["params", "max_torque"], high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + self.reward_spec = Unbounded(shape=(*td_params.shape, 1)) def make_composite_from_td(td): # custom function to convert a ``tensordict`` in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( + else Unbounded( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape ) for key, tensor in td.items() diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index 79ea3b2dfb6..6e5dee781e8 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -9,12 +9,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase @@ -39,28 +34,28 @@ class TicTacToeEnv(EnvBase): output entry). Specs: - CompositeSpec( - output_spec: CompositeSpec( - full_observation_spec: CompositeSpec( - board: DiscreteTensorSpec( + Composite( + output_spec: Composite( + full_observation_spec: Composite( + board: Categorical( shape=torch.Size([3, 3]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - turn: DiscreteTensorSpec( + turn: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - mask: DiscreteTensorSpec( + mask: Categorical( shape=torch.Size([9]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), shape=torch.Size([])), - full_reward_spec: CompositeSpec( - player0: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + full_reward_spec: Composite( + player0: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -68,8 +63,8 @@ class TicTacToeEnv(EnvBase): dtype=torch.float32, domain=continuous), shape=torch.Size([])), - player1: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + player1: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -78,43 +73,43 @@ class TicTacToeEnv(EnvBase): domain=continuous), shape=torch.Size([])), shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), - terminated: DiscreteTensorSpec( + terminated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), - truncated: DiscreteTensorSpec( + truncated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), shape=torch.Size([])), shape=torch.Size([])), - input_spec: CompositeSpec( - full_state_spec: CompositeSpec( - board: DiscreteTensorSpec( + input_spec: Composite( + full_state_spec: Composite( + board: Categorical( shape=torch.Size([3, 3]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - turn: DiscreteTensorSpec( + turn: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - mask: DiscreteTensorSpec( + mask: Categorical( shape=torch.Size([9]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), shape=torch.Size([])), - full_action_spec: CompositeSpec( - action: DiscreteTensorSpec( + full_action_spec: Composite( + action: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=9), dtype=torch.int64, @@ -172,23 +167,21 @@ class TicTacToeEnv(EnvBase): def __init__(self, *, single_player: bool = False, device=None): super().__init__(device=device) self.single_player = single_player - self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec( + self.action_spec: Unbounded = Categorical( n=9, shape=(), device=device, ) - self.full_observation_spec: CompositeSpec = CompositeSpec( - board=UnboundedContinuousTensorSpec( - shape=(3, 3), dtype=torch.int, device=device - ), - turn=DiscreteTensorSpec( + self.full_observation_spec: Composite = Composite( + board=Unbounded(shape=(3, 3), dtype=torch.int, device=device), + turn=Categorical( 2, shape=(1,), dtype=torch.int, device=device, ), - mask=DiscreteTensorSpec( + mask=Categorical( 2, shape=(9,), dtype=torch.bool, @@ -196,22 +189,18 @@ def __init__(self, *, single_player: bool = False, device=None): ), device=device, ) - self.state_spec: CompositeSpec = self.observation_spec.clone() + self.state_spec: Composite = self.observation_spec.clone() - self.reward_spec: UnboundedContinuousTensorSpec = CompositeSpec( + self.reward_spec: Unbounded = Composite( { - ("player0", "reward"): UnboundedContinuousTensorSpec( - shape=(1,), device=device - ), - ("player1", "reward"): UnboundedContinuousTensorSpec( - shape=(1,), device=device - ), + ("player0", "reward"): Unbounded(shape=(1,), device=device), + ("player1", "reward"): Unbounded(shape=(1,), device=device), }, device=device, ) - self.full_done_spec: DiscreteTensorSpec = CompositeSpec( - done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device), + self.full_done_spec: Categorical = Composite( + done=Categorical(2, shape=(1,), dtype=torch.bool, device=device), device=device, ) self.full_done_spec["terminated"] = self.full_done_spec["done"].clone() diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index c7935272c91..d2b6e0f23fa 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -15,11 +15,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper, EnvBase @@ -44,10 +40,10 @@ class default_info_dict_reader(BaseInfoDictReader): Args: keys (list of keys, optional): If provided, the list of keys to get from the info dictionary. Defaults to all keys. - spec (List[TensorSpec], Dict[str, TensorSpec] or CompositeSpec, optional): + spec (List[TensorSpec], Dict[str, TensorSpec] or Composite, optional): If a list of specs is provided, each spec will be matched to its - correspondent key to form a :class:`torchrl.data.CompositeSpec`. - If not provided, a composite spec with :class:`~torchrl.data.UnboundedContinuousTensorSpec` + correspondent key to form a :class:`torchrl.data.Composite`. + If not provided, a composite spec with :class:`~torchrl.data.Unbounded` specs will lazyly be created. ignore_private (bool, optional): If ``True``, private infos (starting with an underscore) will be ignored. Defaults to ``True``. @@ -72,10 +68,7 @@ class default_info_dict_reader(BaseInfoDictReader): def __init__( self, keys: List[str] | None = None, - spec: Sequence[TensorSpec] - | Dict[str, TensorSpec] - | CompositeSpec - | None = None, + spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None, ignore_private: bool = True, ): self.ignore_private = ignore_private @@ -87,19 +80,17 @@ def __init__( if spec is None and keys is None: _info_spec = None elif spec is None: - _info_spec = CompositeSpec( - {key: UnboundedContinuousTensorSpec(()) for key in keys}, shape=[] - ) - elif not isinstance(spec, CompositeSpec): + _info_spec = Composite({key: Unbounded(()) for key in keys}, shape=[]) + elif not isinstance(spec, Composite): if self.keys is not None and len(spec) != len(self.keys): raise ValueError( "If specifying specs for info keys with a sequence, the " "length of the sequence must match the number of keys" ) if isinstance(spec, dict): - _info_spec = CompositeSpec(spec, shape=[]) + _info_spec = Composite(spec, shape=[]) else: - _info_spec = CompositeSpec( + _info_spec = Composite( {key: spec for key, spec in zip(keys, spec)}, shape=[] ) else: @@ -121,7 +112,7 @@ def __call__( keys = [key for key in keys if not key.startswith("_")] self.keys = keys # create an info_spec only if there is none - info_spec = None if self.info_spec is not None else CompositeSpec() + info_spec = None if self.info_spec is not None else Composite() for key in keys: if key in info_dict: val = info_dict[key] @@ -130,7 +121,7 @@ def __call__( tensordict.set(key, val) if info_spec is not None: val = tensordict.get(key) - info_spec[key] = UnboundedContinuousTensorSpec( + info_spec[key] = Unbounded( val.shape, device=val.device, dtype=val.dtype ) elif self.info_spec is not None: diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py index fb01f430fc1..6200987c5a8 100644 --- a/torchrl/envs/libs/_gym_utils.py +++ b/torchrl/envs/libs/_gym_utils.py @@ -12,7 +12,7 @@ from torch.utils._pytree import tree_map from torchrl._utils import implement_for -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import step_mdp, TransformedEnv from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform @@ -37,7 +37,7 @@ def __init__( ), ) self.observation_space = _torchrl_to_gym_spec_transform( - CompositeSpec( + Composite( { key: self.torchrl_env.full_observation_spec[key] for key in self._observation_keys diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index ac4cd71ddad..9542b8e71ff 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -11,11 +11,7 @@ from packaging import version from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.jax_utils import ( _extract_spec, @@ -55,8 +51,8 @@ class BraxWrapper(_EnvWrapper): Args: env (brax.envs.base.PipelineEnv): the environment to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -255,7 +251,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-1, high=1, shape=( @@ -264,15 +260,15 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 ), device=self.device, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=[ *self.batch_size, 1, ], device=self.device, ) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( shape=( *self.batch_size, env.observation_size, @@ -439,8 +435,8 @@ class BraxEnv(BraxWrapper): env_name (str): the environment name of the env to wrap. Must be part of :attr:`~.available_envs`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 5558754de26..2ca62e106f6 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -16,13 +16,12 @@ from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Bounded, + Categorical, + Composite, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict @@ -57,14 +56,10 @@ def _dmcontrol_to_torchrl_spec_transform( ) for k, item in spec.items() } - return CompositeSpec(**spec) + return Composite(**spec) elif isinstance(spec, dm_env.specs.DiscreteArray): # DiscreteArray is a type of BoundedArray so this block needs to go first - action_space_cls = ( - DiscreteTensorSpec - if categorical_discrete_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_discrete_encoding else OneHot if dtype is None: dtype = ( numpy_to_torch_dtype_dict[spec.dtype] @@ -78,7 +73,7 @@ def _dmcontrol_to_torchrl_spec_transform( shape = spec.shape if not len(shape): shape = torch.Size([1]) - return BoundedTensorSpec( + return Bounded( shape=shape, low=spec.minimum, high=spec.maximum, @@ -92,11 +87,9 @@ def _dmcontrol_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return Unbounded(shape=shape, dtype=dtype, device=device) else: - return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) + return Unbounded(shape=shape, dtype=dtype, device=device) else: raise NotImplementedError(type(spec)) @@ -254,10 +247,10 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 reward_spec.shape = torch.Size([1]) self.reward_spec = reward_spec # populate default done spec - done_spec = DiscreteTensorSpec( + done_spec = Categorical( n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device ) - self.done_spec = CompositeSpec( + self.done_spec = Composite( done=done_spec.clone(), truncated=done_spec.clone(), terminated=done_spec.clone(), diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index a029a0beb5b..599645dfdfc 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -13,12 +13,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.utils import _classproperty @@ -35,8 +30,8 @@ class MultiThreadedEnvWrapper(_EnvWrapper): Args: env (envpool.python.envpool.EnvPoolMixin): the envpool to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -161,7 +156,7 @@ def _get_action_spec(self) -> TensorSpec: return action_spec def _get_output_spec(self) -> TensorSpec: - return CompositeSpec( + return Composite( full_observation_spec=self._get_observation_spec(), full_reward_spec=self._get_reward_spec(), full_done_spec=self._get_done_spec(), @@ -180,9 +175,9 @@ def _get_observation_spec(self) -> TensorSpec: categorical_action_encoding=True, ) observation_spec = self._add_shape_to_spec(observation_spec) - if isinstance(observation_spec, CompositeSpec): + if isinstance(observation_spec, Composite): return observation_spec - return CompositeSpec( + return Composite( observation=observation_spec, shape=(self.num_workers,), device=self.device, @@ -192,19 +187,19 @@ def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec: return spec.expand((self.num_workers, *spec.shape)) def _get_reward_spec(self) -> TensorSpec: - return UnboundedContinuousTensorSpec( + return Unbounded( device=self.device, shape=self.batch_size, ) def _get_done_spec(self) -> TensorSpec: - spec = DiscreteTensorSpec( + spec = Categorical( 2, device=self.device, shape=self.batch_size, dtype=torch.bool, ) - return CompositeSpec( + return Composite( done=spec, truncated=spec.clone(), terminated=spec.clone(), @@ -335,8 +330,8 @@ class MultiThreadedEnv(MultiThreadedEnvWrapper): create_env_kwargs (Dict[str, Any], optional): kwargs to be passed to envpool environment constructor. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default for these versions), the environment checker won't be run. diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 9195929e31d..8431d155ee2 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -23,16 +23,15 @@ from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( _minmax_dtype, - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict from torchrl.envs.batched_envs import CloudpickleWrapper @@ -259,11 +258,7 @@ def _gym_to_torchrl_spec_transform( ) return result if isinstance(spec, gym_spaces.discrete.Discrete): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot dtype = ( numpy_to_torch_dtype_dict[spec.dtype] if categorical_action_encoding @@ -271,7 +266,7 @@ def _gym_to_torchrl_spec_transform( ) return action_space_cls(spec.n, device=device, dtype=dtype) elif isinstance(spec, gym_spaces.multi_binary.MultiBinary): - return BinaryDiscreteTensorSpec( + return Binary( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) # a spec type cannot be a string, so we're sure that versions of gym that don't have Sequence will just skip through this @@ -300,11 +295,9 @@ def _gym_to_torchrl_spec_transform( ) return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + MultiCategorical(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec( - spec.nvec, device=device, dtype=dtype - ) + else MultiOneHot(spec.nvec, device=device, dtype=dtype) ) return torch.stack( @@ -337,9 +330,9 @@ def _gym_to_torchrl_spec_transform( and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all() ) return ( - UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype) + Unbounded(shape, device=device, dtype=dtype) if is_unbounded - else BoundedTensorSpec( + else Bounded( low, high, shape, @@ -368,7 +361,7 @@ def _gym_to_torchrl_spec_transform( remap_state_to_observation=remap_state_to_observation, ) # the batch-size must be set later - return CompositeSpec(spec_out, device=device) + return Composite(spec_out, device=device) elif isinstance(spec, gym_spaces.dict.Dict): return _gym_to_torchrl_spec_transform( spec.spaces, @@ -445,19 +438,19 @@ def _torchrl_to_gym_spec_transform( return gym_spaces.Tuple( tuple(_torchrl_to_gym_spec_transform(spec) for spec in spec.unbind(0)) ) - if isinstance(spec, MultiDiscreteTensorSpec): + if isinstance(spec, MultiCategorical): return _multidiscrete_convert(gym_spaces, spec) - if isinstance(spec, MultiOneHotDiscreteTensorSpec): + if isinstance(spec, MultiOneHot): return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec) - if isinstance(spec, BinaryDiscreteTensorSpec): + if isinstance(spec, Binary): return gym_spaces.multi_binary.MultiBinary(spec.shape[-1]) - if isinstance(spec, DiscreteTensorSpec): + if isinstance(spec, Categorical): return gym_spaces.discrete.Discrete( spec.n ) # dtype=torch_to_numpy_dtype_dict[spec.dtype]) - if isinstance(spec, OneHotDiscreteTensorSpec): + if isinstance(spec, OneHot): return gym_spaces.discrete.Discrete(spec.n) - if isinstance(spec, UnboundedContinuousTensorSpec): + if isinstance(spec, Unbounded): minval, maxval = _minmax_dtype(spec.dtype) return gym_spaces.Box( low=minval, @@ -465,7 +458,7 @@ def _torchrl_to_gym_spec_transform( shape=shape, dtype=torch_to_numpy_dtype_dict[spec.dtype], ) - if isinstance(spec, UnboundedDiscreteTensorSpec): + if isinstance(spec, Unbounded): minval, maxval = _minmax_dtype(spec.dtype) return gym_spaces.Box( low=minval, @@ -473,9 +466,9 @@ def _torchrl_to_gym_spec_transform( shape=shape, dtype=torch_to_numpy_dtype_dict[spec.dtype], ) - if isinstance(spec, BoundedTensorSpec): + if isinstance(spec, Bounded): return _box_convert(spec, gym_spaces, shape) - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): # remove batch size while spec.shape: spec = spec[0] @@ -624,8 +617,8 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): or :class:`gym.VectorEnv`) are supported and the environment batch-size will reflect the number of environments executed in parallel. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -865,10 +858,7 @@ def _build_env( def read_action(self, action): action = super().read_action(action) - if ( - isinstance(self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)) - and action.size == 1 - ): + if isinstance(self.action_spec, (OneHot, Categorical)) and action.size == 1: # some envs require an integer for indexing action = int(action) return action @@ -1012,13 +1002,13 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) - if not isinstance(observation_spec, CompositeSpec): + if not isinstance(observation_spec, Composite): if self.from_pixels: - observation_spec = CompositeSpec( + observation_spec = Composite( pixels=observation_spec, shape=cur_batch_size ) else: - observation_spec = CompositeSpec( + observation_spec = Composite( observation=observation_spec, shape=cur_batch_size ) elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size: @@ -1032,7 +1022,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 categorical_action_encoding=self._categorical_action_encoding, ) else: - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( shape=[1], device=self.device, ) @@ -1053,15 +1043,15 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 @implement_for("gym", None, "0.26") def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1070,15 +1060,15 @@ def _make_done_spec(self): # noqa: F811 @implement_for("gym", "0.26", None) def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1087,15 +1077,15 @@ def _make_done_spec(self): # noqa: F811 @implement_for("gymnasium", "0.27", None) def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1250,8 +1240,8 @@ class GymEnv(GymWrapper): Args: env_name (str): the environment id registered in `gym.registry`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -1567,7 +1557,7 @@ class terminal_obs_reader(default_info_dict_reader): replaced. Args: - observation_spec (CompositeSpec): The observation spec of the gym env. + observation_spec (Composite): The observation spec of the gym env. backend (str, optional): the backend of the env. One of `"sb3"` for stable-baselines3 or `"gym"` for gym/gymnasium. @@ -1585,7 +1575,7 @@ class terminal_obs_reader(default_info_dict_reader): "gym": "final_info", } - def __init__(self, observation_spec: CompositeSpec, backend, name="final"): + def __init__(self, observation_spec: Composite, backend, name="final"): super().__init__() self.name = name self._obs_spec = observation_spec.clone() diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 53752147acc..4180c42b2dc 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -54,8 +54,8 @@ class HabitatEnv(GymEnv): Args: env_name (str): The environment to execute. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 4c56bea304a..fb37639ad37 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -14,7 +14,7 @@ import torch from tensordict import TensorDictBase -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs.libs.gym import GymWrapper from torchrl.envs.utils import _classproperty, make_composite_from_td @@ -59,7 +59,7 @@ def __init__( def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 super()._make_specs(env, batch_size=self.batch_size) - self.full_done_spec = CompositeSpec( + self.full_done_spec = Composite( { key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True) diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index d1d1094a264..052f538f0c4 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -13,12 +13,7 @@ # from jax import dlpack as jax_dlpack, numpy as jnp from tensordict import make_tensordict, TensorDictBase from torch.utils import dlpack as torch_dlpack -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import numpy_to_torch_dtype_dict _has_jax = importlib.util.find_spec("jax") is not None @@ -155,15 +150,11 @@ def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> Tensor if key in ("reward", "done"): shape = (*shape, 1) if data.dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=data.dtype, device=data.device - ) + return Unbounded(shape=shape, dtype=data.dtype, device=data.device) else: - return UnboundedDiscreteTensorSpec( - shape=shape, dtype=data.dtype, device=data.device - ) + return Unbounded(shape=shape, dtype=data.dtype, device=data.device) elif isinstance(data, TensorDictBase): - return CompositeSpec( + return Composite( {key: _extract_spec(value, key=key) for key, value in data.items()} ) else: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 071c8f7f56c..dbbc980e8cc 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -18,16 +18,15 @@ _has_jumanji = importlib.util.find_spec("jumanji") is not None from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, + Bounded, + Categorical, + Composite, DEVICE_TYPING, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs.gym_like import GymLikeEnv @@ -59,19 +58,13 @@ def _jumanji_to_torchrl_spec_transform( import jumanji if isinstance(spec, jumanji.specs.DiscreteArray): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return action_space_cls(spec.num_values, dtype=dtype, device=device) if isinstance(spec, jumanji.specs.MultiDiscreteArray): action_space_cls = ( - MultiDiscreteTensorSpec - if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec + MultiCategorical if categorical_action_encoding else MultiOneHot ) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] @@ -82,7 +75,7 @@ def _jumanji_to_torchrl_spec_transform( shape = spec.shape if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] - return BoundedTensorSpec( + return Bounded( shape=shape, low=np.asarray(spec.minimum), high=np.asarray(spec.maximum), @@ -94,11 +87,9 @@ def _jumanji_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return Unbounded(shape=shape, dtype=dtype, device=device) else: - return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) + return Unbounded(shape=shape, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"): new_spec = {} for key, value in spec.__dict__.items(): @@ -110,7 +101,7 @@ def _jumanji_to_torchrl_spec_transform( new_spec[key] = _jumanji_to_torchrl_spec_transform( value, dtype, device, categorical_action_encoding ) - return CompositeSpec(**new_spec) + return Composite(**new_spec) else: raise TypeError(f"Unsupported spec type {type(spec)}") @@ -140,8 +131,8 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): Args: env (jumanji.env.Environment): the env to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -433,9 +424,9 @@ def _make_observation_spec(self, env) -> TensorSpec: spec = env.observation_spec new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) if isinstance(spec, jumanji.specs.Array): - return CompositeSpec(observation=new_spec).expand(self.batch_size) + return Composite(observation=new_spec).expand(self.batch_size) elif isinstance(spec, jumanji.specs.Spec): - return CompositeSpec(**{k: v for k, v in new_spec.items()}).expand( + return Composite(**{k: v for k, v in new_spec.items()}).expand( self.batch_size ) else: @@ -681,8 +672,8 @@ class JumanjiEnv(JumanjiWrapper): Args: env_name (str): the name of the environment to wrap. Must be part of :attr:`~.available_envs`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 446b3dac292..b8e52031a23 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase -from torchrl.data import CompositeSpec, DiscreteTensorSpec, TensorSpec +from torchrl.data import Categorical, Composite, TensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType @@ -246,9 +246,9 @@ def _make_specs( } self._make_group_map() - action_spec = CompositeSpec() - observation_spec = CompositeSpec() - reward_spec = CompositeSpec() + action_spec = Composite() + observation_spec = Composite() + reward_spec = Composite() for group in self.group_map.keys(): ( @@ -266,11 +266,9 @@ def _make_specs( reward_spec[group] = group_reward_spec observation_spec.update(torchrl_state_spec) - self.done_spec = CompositeSpec( + self.done_spec = Composite( { - "done": DiscreteTensorSpec( - n=2, shape=torch.Size((1,)), dtype=torch.bool - ), + "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool), }, ) self.action_spec = action_spec @@ -292,7 +290,7 @@ def _make_group_specs( for agent_name in self.group_map[group]: agent_index = self.agent_names_to_indices_map[agent_name] action_specs.append( - CompositeSpec( + Composite( { "action": torchrl_agent_act_specs[ agent_index @@ -301,7 +299,7 @@ def _make_group_specs( ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": torchrl_agent_obs_specs[ agent_index @@ -310,7 +308,7 @@ def _make_group_specs( ) ) reward_specs.append( - CompositeSpec({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,) + Composite({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,) ) # Create multi-agent specs diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 7ac318e03cb..55b246bd902 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -8,12 +8,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase from torchrl.envs.transforms import Compose, DoubleToFloat, RenameTransform from torchrl.envs.utils import _classproperty @@ -24,17 +19,13 @@ def _make_composite_from_td(td): # custom funtion to convert a tensordict in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: _make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) if tensor.dtype in (torch.float16, torch.float32, torch.float64) - else UnboundedDiscreteTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) for key, tensor in td.items() }, shape=td.shape, @@ -115,10 +106,10 @@ def __init__(self, dataset_name, device="cpu", batch_size=None): .reshape(self.batch_size) .exclude("index") ) - self.action_spec = DiscreteTensorSpec( + self.action_spec = Categorical( self._data.max_outcome_val + 1, shape=self.batch_size, device=self.device ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*self.batch_size, 1)) + self.reward_spec = Unbounded(shape=(*self.batch_size, 1)) def _reset(self, tensordict): data = self._data.sample() diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index eb94a27cbba..e34ca4600a7 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -13,12 +13,7 @@ import torch from tensordict import TensorDictBase -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType @@ -308,24 +303,24 @@ def _make_specs( check_marl_grouping(self.group_map, self.possible_agents) self.has_action_mask = {group: False for group in self.group_map.keys()} - action_spec = CompositeSpec() - observation_spec = CompositeSpec() - reward_spec = CompositeSpec() - done_spec = CompositeSpec( + action_spec = Composite() + observation_spec = Composite() + reward_spec = Composite() + done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, device=self.device, ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, device=self.device, ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -356,7 +351,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): observation_specs = [] for agent in agent_names: action_specs.append( - CompositeSpec( + Composite( { "action": _gym_to_torchrl_spec_transform( self.action_space(agent), @@ -368,7 +363,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": _gym_to_torchrl_spec_transform( self.observation_space(agent), @@ -386,12 +381,12 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): # We uniform this by removing it from both places and optionally set it in a standard location. group_observation_inner_spec = group_observation_spec["observation"] if ( - isinstance(group_observation_inner_spec, CompositeSpec) + isinstance(group_observation_inner_spec, Composite) and "action_mask" in group_observation_inner_spec.keys() ): self.has_action_mask[group_name] = True del group_observation_inner_spec["action_mask"] - group_observation_spec["action_mask"] = DiscreteTensorSpec( + group_observation_spec["action_mask"] = Categorical( n=2, shape=group_action_spec["action"].shape if not self.categorical_actions @@ -404,16 +399,16 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): ) if self.use_mask: - group_observation_spec["mask"] = DiscreteTensorSpec( + group_observation_spec["mask"] = Categorical( n=2, shape=torch.Size((n_agents,)), dtype=torch.bool, device=self.device, ) - group_reward_spec = CompositeSpec( + group_reward_spec = Composite( { - "reward": UnboundedContinuousTensorSpec( + "reward": Unbounded( shape=torch.Size((n_agents, 1)), device=self.device, dtype=torch.float32, @@ -421,21 +416,21 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): }, shape=torch.Size((n_agents,)), ) - group_done_spec = CompositeSpec( + group_done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, device=self.device, ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, device=self.device, ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, @@ -473,11 +468,11 @@ def _init_env(self): info_specs = [] for agent in agents: info_specs.append( - CompositeSpec( + Composite( { - "info": CompositeSpec( + "info": Composite( { - key: UnboundedContinuousTensorSpec( + key: Unbounded( shape=torch.as_tensor(value).shape, device=self.device, ) @@ -495,7 +490,7 @@ def _init_env(self): group_action_spec = self.input_spec[ "full_action_spec", group, "action" ] - self.observation_spec[group]["action_mask"] = DiscreteTensorSpec( + self.observation_spec[group]["action_mask"] = Categorical( n=2, shape=group_action_spec.shape if not self.categorical_actions @@ -518,7 +513,7 @@ def _init_env(self): ) except AttributeError: state_example = torch.as_tensor(self.state(), device=self.device) - state_spec = UnboundedContinuousTensorSpec( + state_spec = Unbounded( shape=state_example.shape, dtype=state_example.dtype, device=self.device, @@ -809,9 +804,7 @@ def _update_action_mask(self, td, observation_dict, info_dict): del agent_info["action_mask"] group_action_spec = self.input_spec["full_action_spec", group, "action"] - if isinstance( - group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec) - ): + if isinstance(group_action_spec, (Categorical, OneHot)): # We update the mask for available actions group_action_spec.update_mask(group_mask.clone()) diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 5e5c8f52393..30d9c644ced 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -12,7 +12,7 @@ import numpy as np import torch from tensordict import TensorDict -from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.envs.libs.gym import ( _AsyncMeta, _gym_to_torchrl_spec_transform, @@ -80,8 +80,8 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild): Args: env_name (str): the environment name to build. Must be one of :attr:`.available_envs` categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -305,7 +305,7 @@ def get_obs(): ) self.observation_spec = observation_spec - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=(1,), device=self.device, ) # default diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index d460eb38f1e..67e71da0d5a 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -10,13 +10,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.utils import _classproperty, ACTION_MASK_ERROR @@ -224,11 +218,11 @@ def _build_env( def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821 self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=torch.Size((1,)), device=self.device, ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -241,54 +235,50 @@ def _init_env(self) -> None: self._env.reset() self._update_action_mask() - def _make_action_spec(self) -> CompositeSpec: + def _make_action_spec(self) -> Composite: if self.categorical_actions: - action_spec = DiscreteTensorSpec( + action_spec = Categorical( self.n_actions, shape=torch.Size((self.n_agents,)), device=self.device, dtype=torch.long, ) else: - action_spec = OneHotDiscreteTensorSpec( + action_spec = OneHot( self.n_actions, shape=torch.Size((self.n_agents, self.n_actions)), device=self.device, dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( { - "agents": CompositeSpec( + "agents": Composite( {"action": action_spec}, shape=torch.Size((self.n_agents,)) ) } ) return spec - def _make_observation_spec(self) -> CompositeSpec: - obs_spec = BoundedTensorSpec( + def _make_observation_spec(self) -> Composite: + obs_spec = Bounded( low=-1.0, high=1.0, shape=torch.Size([self.n_agents, self.get_obs_size()]), device=self.device, dtype=torch.float32, ) - info_spec = CompositeSpec( + info_spec = Composite( { - "battle_won": DiscreteTensorSpec( - 2, dtype=torch.bool, device=self.device - ), - "episode_limit": DiscreteTensorSpec( - 2, dtype=torch.bool, device=self.device - ), - "dead_allies": BoundedTensorSpec( + "battle_won": Categorical(2, dtype=torch.bool, device=self.device), + "episode_limit": Categorical(2, dtype=torch.bool, device=self.device), + "dead_allies": Bounded( low=0, high=self.n_agents, dtype=torch.long, device=self.device, shape=(), ), - "dead_enemies": BoundedTensorSpec( + "dead_enemies": Bounded( low=0, high=self.n_enemies, dtype=torch.long, @@ -297,19 +287,19 @@ def _make_observation_spec(self) -> CompositeSpec: ), } ) - mask_spec = DiscreteTensorSpec( + mask_spec = Categorical( 2, torch.Size([self.n_agents, self.n_actions]), device=self.device, dtype=torch.bool, ) - spec = CompositeSpec( + spec = Composite( { - "agents": CompositeSpec( + "agents": Composite( {"observation": obs_spec, "action_mask": mask_spec}, shape=torch.Size((self.n_agents,)), ), - "state": BoundedTensorSpec( + "state": Bounded( low=-1.0, high=1.0, shape=torch.Size((self.get_state_size(),)), diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 9751e84a3ac..5811580826d 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -12,16 +12,16 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, + Bounded, + Categorical, + Composite, DEVICE_TYPING, - DiscreteTensorSpec, - LazyStackedCompositeSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, + StackedComposite, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs.common import _EnvWrapper, EnvBase @@ -57,11 +57,7 @@ def _vmas_to_torchrl_spec_transform( ) -> TensorSpec: gym_spaces = gym_backend("spaces") if isinstance(spec, gym_spaces.discrete.Discrete): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot dtype = ( numpy_to_torch_dtype_dict[spec.dtype] if categorical_action_encoding @@ -75,9 +71,9 @@ def _vmas_to_torchrl_spec_transform( else torch.long ) return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + MultiCategorical(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + else MultiOneHot(spec.nvec, device=device, dtype=dtype) ) elif isinstance(spec, gym_spaces.Box): shape = spec.shape @@ -88,9 +84,9 @@ def _vmas_to_torchrl_spec_transform( high = torch.tensor(spec.high, device=device, dtype=dtype) is_unbounded = low.isinf().all() and high.isinf().all() return ( - UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype) + Unbounded(shape, device=device, dtype=dtype) if is_unbounded - else BoundedTensorSpec( + else Bounded( low, high, shape, @@ -322,9 +318,9 @@ def _make_specs( self.group_map = self.group_map.get_group_map(self.agent_names) check_marl_grouping(self.group_map, self.agent_names) - self.unbatched_action_spec = CompositeSpec(device=self.device) - self.unbatched_observation_spec = CompositeSpec(device=self.device) - self.unbatched_reward_spec = CompositeSpec(device=self.device) + self.unbatched_action_spec = Composite(device=self.device) + self.unbatched_observation_spec = Composite(device=self.device) + self.unbatched_reward_spec = Composite(device=self.device) self.het_specs = False self.het_specs_map = {} @@ -341,14 +337,14 @@ def _make_specs( if group_info_spec is not None: self.unbatched_observation_spec[(group, "info")] = group_info_spec group_het_specs = isinstance( - group_observation_spec, LazyStackedCompositeSpec - ) or isinstance(group_action_spec, LazyStackedCompositeSpec) + group_observation_spec, StackedComposite + ) or isinstance(group_action_spec, StackedComposite) self.het_specs_map[group] = group_het_specs self.het_specs = self.het_specs or group_het_specs - self.unbatched_done_spec = CompositeSpec( + self.unbatched_done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -380,7 +376,7 @@ def _make_unbatched_group_specs(self, group: str): agent_index = self.agent_names_to_indices_map[agent_name] agent = self.agents[agent_index] action_specs.append( - CompositeSpec( + Composite( { "action": _vmas_to_torchrl_spec_transform( self.action_space[agent_index], @@ -391,7 +387,7 @@ def _make_unbatched_group_specs(self, group: str): ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": _vmas_to_torchrl_spec_transform( self.observation_space[agent_index], @@ -402,9 +398,9 @@ def _make_unbatched_group_specs(self, group: str): ) ) reward_specs.append( - CompositeSpec( + Composite( { - "reward": UnboundedContinuousTensorSpec( + "reward": Unbounded( shape=torch.Size((1,)), device=self.device, ) # shape = (1,) @@ -414,9 +410,9 @@ def _make_unbatched_group_specs(self, group: str): agent_info = self.scenario.info(agent) if len(agent_info): info_specs.append( - CompositeSpec( + Composite( { - key: UnboundedContinuousTensorSpec( + key: Unbounded( shape=_selective_unsqueeze( value, batch_size=self.batch_size ).shape[1:], diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index f6b3f97cd4a..2a3c0198f9c 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -27,18 +27,18 @@ class ModelBasedEnvBase(EnvBase): Example: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.state_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)), + ... self.state_spec = Composite( + ... hidden_observation=Unbounded((4,)), ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict({}, @@ -84,10 +84,10 @@ class ModelBasedEnvBase(EnvBase): Properties: - - observation_spec (CompositeSpec): sampling spec of the observations; + - observation_spec (Composite): sampling spec of the observations; - action_spec (TensorSpec): sampling spec of the actions; - reward_spec (TensorSpec): sampling spec of the rewards; - - input_spec (CompositeSpec): sampling spec of the inputs; + - input_spec (Composite): sampling spec of the inputs; - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. - device (torch.device): device where the env input and output are expected to live diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 5609861c75f..f5636f76c5a 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -9,7 +9,7 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModule -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase @@ -39,7 +39,7 @@ def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" super().set_specs_from_env(env) self.action_spec = self.action_spec.to(self.device) - self.state_spec = CompositeSpec( + self.state_spec = Composite( state=self.observation_spec["state"], belief=self.observation_spec["belief"], shape=env.batch_size, diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index 35f122b770a..b3ac334a5d8 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -10,7 +10,7 @@ import torchrl.objectives.common from tensordict import TensorDictBase from tensordict.utils import expand_as_right, NestedKey -from torchrl.data.tensor_specs import UnboundedDiscreteTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform @@ -179,7 +179,7 @@ def _reset(self, tensordict, tensordict_reset): def transform_observation_spec(self, observation_spec): full_done_spec = self.parent.output_spec["full_done_spec"] observation_spec[self.eol_key] = full_done_spec[self.done_key].clone() - observation_spec[self.lives_key] = UnboundedDiscreteTensorSpec( + observation_spec[self.lives_key] = Unbounded( self.parent.batch_size, device=self.parent.device, dtype=torch.int64, diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 546321d5815..d4505a4d240 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -11,11 +11,7 @@ from torch.hub import load_state_dict_from_url from torch.nn import Identity -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( CatTensors, @@ -103,8 +99,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("_R3MNet can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("_R3MNet can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -116,7 +112,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, self.outdim]), device=device ) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 33874393038..b41a290d3f7 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -9,7 +9,7 @@ from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams from tensordict.utils import is_seq_of_nested_key from torch import nn -from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Composite, Unbounded from torchrl.envs.transforms.transforms import Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param @@ -186,7 +186,7 @@ def _step( forward = _call - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: output_spec = super().transform_output_spec(output_spec) # todo: here we'll need to use the reward_key once it's implemented # parent = self.parent @@ -195,17 +195,17 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: if in_key == "reward" and out_key == "reward": parent = self.parent - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) - output_spec["full_reward_spec"] = CompositeSpec( + output_spec["full_reward_spec"] = Composite( {parent.reward_key: reward_spec}, shape=output_spec["full_reward_spec"].shape, ) elif in_key == "reward": parent = self.parent - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) @@ -214,7 +214,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: observation_spec[out_key] = reward_spec else: observation_spec = output_spec["full_observation_spec"] - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=observation_spec[in_key].shape ) # then we need to populate the output keys diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 255af86a61e..8859af2f9cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -48,16 +48,16 @@ from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, + Binary, + Bounded, + Categorical, + Composite, ContinuousBox, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F @@ -80,14 +80,14 @@ def _apply_to_composite(function): @wraps(function) def new_fun(self, observation_spec): - if isinstance(observation_spec, CompositeSpec): + if isinstance(observation_spec, Composite): _specs = observation_spec._specs in_keys = self.in_keys out_keys = self.out_keys for in_key, out_key in zip(in_keys, out_keys): if in_key in observation_spec.keys(True, True): _specs[out_key] = function(self, observation_spec[in_key].clone()) - return CompositeSpec( + return Composite( _specs, shape=observation_spec.shape, device=observation_spec.device ) else: @@ -109,7 +109,7 @@ def new_fun(self, input_spec): action_spec = input_spec["full_action_spec"].clone() state_spec = input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) + state_spec = Composite(shape=input_spec.shape, device=input_spec.device) else: state_spec = state_spec.clone() in_keys_inv = self.in_keys_inv @@ -122,7 +122,7 @@ def new_fun(self, input_spec): action_spec[out_key] = function(self, action_spec[in_key].clone()) elif in_key in state_spec.keys(True, True): state_spec[out_key] = function(self, state_spec[in_key].clone()) - return CompositeSpec( + return Composite( full_state_spec=state_spec, full_action_spec=action_spec, shape=input_spec.shape, @@ -360,7 +360,7 @@ def transform_env_batch_size(self, batch_size: torch.Size): """Transforms the batch-size of the parent env.""" return batch_size - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: """Transforms the output spec such that the resulting spec matches transform mapping. This method should generally be left untouched. Changes should be implemented using @@ -831,7 +831,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset): return tensordict_reset def _complete_done( - cls, done_spec: CompositeSpec, data: TensorDictBase + cls, done_spec: Composite, data: TensorDictBase ) -> TensorDictBase: # This step has already been completed. We assume the transform module do their job correctly. return data @@ -1465,7 +1465,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return BoundedTensorSpec( + return Bounded( shape=observation_spec.shape, device=observation_spec.device, dtype=observation_spec.dtype, @@ -1477,7 +1477,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: for key in self.in_keys: if key in self.parent.reward_keys: spec = self.parent.output_spec["full_reward_spec"][key] - self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec( + self.parent.output_spec["full_reward_spec"][key] = Bounded( shape=spec.shape, device=spec.device, dtype=spec.dtype, @@ -1685,7 +1685,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec target = self.parent.full_done_spec[in_key] else: raise RuntimeError(f"in_key {in_key} not found in output_spec.") - target_return_spec = UnboundedContinuousTensorSpec( + target_return_spec = Unbounded( shape=target.shape, dtype=target.dtype, device=target.device, @@ -1744,8 +1744,8 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - if isinstance(reward_spec, UnboundedContinuousTensorSpec): - return BoundedTensorSpec( + if isinstance(reward_spec, Unbounded): + return Bounded( self.clamp_min, self.clamp_max, shape=reward_spec.shape, @@ -1798,7 +1798,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - return BinaryDiscreteTensorSpec( + return Binary( n=1, device=reward_spec.device, shape=reward_spec.shape, @@ -3321,7 +3321,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - if isinstance(reward_spec, UnboundedContinuousTensorSpec): + if isinstance(reward_spec, Unbounded): return reward_spec else: raise NotImplementedError( @@ -3424,10 +3424,10 @@ class DTypeCastTransform(Transform): >>> class MyEnv(EnvBase): ... def __init__(self): ... super().__init__() - ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) - ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) - ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) + ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64)) + ... self.action_spec = Unbounded((), dtype=torch.float64) + ... self.reward_spec = Unbounded((1,), dtype=torch.float64) + ... self.done_spec = Unbounded((1,), dtype=torch.bool) ... def _reset(self, data=None): ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) ... def _step(self, data): @@ -3640,7 +3640,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: return state.to(self.dtype_in) def _transform_spec(self, spec: TensorSpec) -> None: - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): for key in spec: self._transform_spec(spec[key]) else: @@ -3685,7 +3685,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: raise RuntimeError return input_spec - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.in_keys is None: raise NotImplementedError( f"Calling transform_reward_spec without a parent environment isn't supported yet for {type(self)}." @@ -3794,10 +3794,10 @@ class DoubleToFloat(DTypeCastTransform): >>> class MyEnv(EnvBase): ... def __init__(self): ... super().__init__() - ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) - ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) - ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) + ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64)) + ... self.action_spec = Unbounded((), dtype=torch.float64) + ... self.reward_spec = Unbounded((1,), dtype=torch.float64) + ... self.done_spec = Unbounded((1,), dtype=torch.bool) ... def _reset(self, data=None): ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) ... def _step(self, data): @@ -4010,13 +4010,13 @@ def _sync_orig_device(self): return self._sync_orig_device return sync_func - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: if self._map_env_device: return input_spec.to(self.device) else: return super().transform_input_spec(input_spec) - def transform_action_spec(self, full_action_spec: CompositeSpec) -> CompositeSpec: + def transform_action_spec(self, full_action_spec: Composite) -> Composite: full_action_spec = full_action_spec.clear_device_() for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): if in_key not in full_action_spec.keys(True, True): @@ -4024,7 +4024,7 @@ def transform_action_spec(self, full_action_spec: CompositeSpec) -> CompositeSpe full_action_spec[out_key] = full_action_spec[in_key].to(self.device) return full_action_spec - def transform_state_spec(self, full_state_spec: CompositeSpec) -> CompositeSpec: + def transform_state_spec(self, full_state_spec: Composite) -> Composite: full_state_spec = full_state_spec.clear_device_() for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): if in_key not in full_state_spec.keys(True, True): @@ -4032,15 +4032,13 @@ def transform_state_spec(self, full_state_spec: CompositeSpec) -> CompositeSpec: full_state_spec[out_key] = full_state_spec[in_key].to(self.device) return full_state_spec - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self._map_env_device: return output_spec.to(self.device) else: return super().transform_output_spec(output_spec) - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: + def transform_observation_spec(self, observation_spec: Composite) -> Composite: observation_spec = observation_spec.clear_device_() for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key not in observation_spec.keys(True, True): @@ -4048,7 +4046,7 @@ def transform_observation_spec( observation_spec[out_key] = observation_spec[in_key].to(self.device) return observation_spec - def transform_done_spec(self, full_done_spec: CompositeSpec) -> CompositeSpec: + def transform_done_spec(self, full_done_spec: Composite) -> Composite: full_done_spec = full_done_spec.clear_device_() for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key not in full_done_spec.keys(True, True): @@ -4056,7 +4054,7 @@ def transform_done_spec(self, full_done_spec: CompositeSpec) -> CompositeSpec: full_done_spec[out_key] = full_done_spec[in_key].to(self.device) return full_done_spec - def transform_reward_spec(self, full_reward_spec: CompositeSpec) -> CompositeSpec: + def transform_reward_spec(self, full_reward_spec: Composite) -> Composite: full_reward_spec = full_reward_spec.clear_device_() for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key not in full_reward_spec.keys(True, True): @@ -4215,13 +4213,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec self._initialized = True # check that all keys are in observation_spec - if len(self.in_keys) > 1 and not isinstance(observation_spec, CompositeSpec): + if len(self.in_keys) > 1 and not isinstance(observation_spec, Composite): raise ValueError( "CatTensor cannot infer the output observation spec as there are multiple input keys but " "only one observation_spec." ) - if isinstance(observation_spec, CompositeSpec) and len( + if isinstance(observation_spec, Composite) and len( [key for key in self.in_keys if key not in observation_spec.keys(True)] ): raise ValueError( @@ -4229,7 +4227,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec "Make sure the environment has an observation_spec attribute that includes all the specs needed for CatTensor." ) - if not isinstance(observation_spec, CompositeSpec): + if not isinstance(observation_spec, Composite): # by def, there must be only one key return observation_spec @@ -4249,7 +4247,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec device = spec0.device shape[self.dim] = sum_shape shape = torch.Size(shape) - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=shape, dtype=spec0.dtype, device=device, @@ -4357,14 +4355,14 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: action = nn.functional.one_hot(action, self.num_actions_effective) return action - def transform_input_spec(self, input_spec: CompositeSpec): + def transform_input_spec(self, input_spec: Composite): input_spec = input_spec.clone() for key in input_spec["full_action_spec"].keys(True, True): key = ("full_action_spec", key) break else: raise KeyError("key not found in action_spec.") - input_spec[key] = OneHotDiscreteTensorSpec( + input_spec[key] = OneHot( self.max_actions, shape=(*input_spec[key].shape[:-1], self.max_actions), device=input_spec.device, @@ -4526,9 +4524,9 @@ class TensorDictPrimer(Transform): tensordict with the desired features. Args: - primers (dict or CompositeSpec, optional): a dictionary containing + primers (dict or Composite, optional): a dictionary containing key-spec pairs which will be used to populate the input tensordict. - :class:`~torchrl.data.CompositeSpec` instances are supported too. + :class:`~torchrl.data.Composite` instances are supported too. random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. @@ -4557,7 +4555,7 @@ class TensorDictPrimer(Transform): >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1")) >>> env = TransformedEnv(base_env) >>> # the env is batch-locked, so the leading dims of the spec must match those of the env - >>> env.append_transform(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 3]))) + >>> env.append_transform(TensorDictPrimer(mykey=Unbounded([2, 3]))) >>> td = env.reset() >>> print(td) TensorDict( @@ -4598,7 +4596,7 @@ class TensorDictPrimer(Transform): def __init__( self, - primers: dict | CompositeSpec = None, + primers: dict | Composite = None, random: bool | None = None, default_value: float | Callable @@ -4615,8 +4613,8 @@ def __init__( "as kwargs." ) kwargs = primers - if not isinstance(kwargs, CompositeSpec): - kwargs = CompositeSpec(kwargs) + if not isinstance(kwargs, Composite): + kwargs = Composite(kwargs) self.primers = kwargs if random and default_value: raise ValueError( @@ -4698,12 +4696,10 @@ def to(self, *args, **kwargs): def _expand_shape(self, spec): return spec.expand((*self.parent.batch_size, *spec.shape)) - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: - if not isinstance(observation_spec, CompositeSpec): + def transform_observation_spec(self, observation_spec: Composite) -> Composite: + if not isinstance(observation_spec, Composite): raise ValueError( - f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." ) if self.primers.shape != observation_spec.shape: @@ -4866,7 +4862,7 @@ def __init__( ) random = state_dim is not None and action_dim is not None shape = tuple(shape) + tail_dim - primers = {"_eps_gSDE": UnboundedContinuousTensorSpec(shape=shape)} + primers = {"_eps_gSDE": Unbounded(shape=shape)} super().__init__(primers=primers, random=random, **kwargs) @@ -5325,8 +5321,8 @@ def __setstate__(self, state: Dict[str, Any]): @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if isinstance(observation_spec, BoundedTensorSpec): - return UnboundedContinuousTensorSpec( + if isinstance(observation_spec, Bounded): + return Unbounded( shape=observation_spec.shape, dtype=observation_spec.dtype, device=observation_spec.device, @@ -5540,13 +5536,13 @@ def _step( def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: state_spec = input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) + state_spec = Composite(shape=input_spec.shape, device=input_spec.device) state_spec.update(self._generate_episode_reward_spec()) input_spec["full_state_spec"] = state_spec return input_spec - def _generate_episode_reward_spec(self) -> CompositeSpec: - episode_reward_spec = CompositeSpec() + def _generate_episode_reward_spec(self) -> Composite: + episode_reward_spec = Composite() reward_spec = self.parent.full_reward_spec reward_spec_keys = self.parent.reward_keys # Define episode specs for all out_keys @@ -5559,7 +5555,7 @@ def _generate_episode_reward_spec(self) -> CompositeSpec: temp_rew_spec = reward_spec for sub_key in out_key[:-1]: if ( - not isinstance(temp_rew_spec, CompositeSpec) + not isinstance(temp_rew_spec, Composite) or sub_key not in temp_rew_spec.keys() ): break @@ -5580,8 +5576,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec """Transforms the observation spec, adding the new keys generated by RewardSum.""" if self.reward_spec: return observation_spec - if not isinstance(observation_spec, CompositeSpec): - observation_spec = CompositeSpec( + if not isinstance(observation_spec, Composite): + observation_spec = Composite( observation=observation_spec, shape=self.parent.batch_size ) observation_spec.update(self._generate_episode_reward_spec()) @@ -5844,12 +5840,10 @@ def _step( next_tensordict.set(truncated_key, truncated) return next_tensordict - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: - if not isinstance(observation_spec, CompositeSpec): + def transform_observation_spec(self, observation_spec: Composite) -> Composite: + if not isinstance(observation_spec, Composite): raise ValueError( - f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." ) full_done_spec = self.parent.output_spec["full_done_spec"] for step_count_key in self.step_count_keys: @@ -5871,7 +5865,7 @@ def transform_observation_spec( raise KeyError( f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." ) - observation_spec[step_count_key] = BoundedTensorSpec( + observation_spec[step_count_key] = Bounded( shape=shape, dtype=torch.int64, device=observation_spec.device, @@ -5880,7 +5874,7 @@ def transform_observation_spec( ) return super().transform_observation_spec(observation_spec) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.max_steps: full_done_spec = self.parent.output_spec["full_done_spec"] for truncated_key in self.truncated_keys: @@ -5902,7 +5896,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: raise KeyError( f"Could not find root of truncated_key {truncated_key} in done keys {self.done_keys}." ) - full_done_spec[truncated_key] = DiscreteTensorSpec( + full_done_spec[truncated_key] = Categorical( 2, dtype=torch.bool, device=output_spec.device, shape=shape ) if self.update_done: @@ -5925,19 +5919,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: raise KeyError( f"Could not find root of stop_key {done_key} in done keys {self.done_keys}." ) - full_done_spec[done_key] = DiscreteTensorSpec( + full_done_spec[done_key] = Categorical( 2, dtype=torch.bool, device=output_spec.device, shape=shape ) output_spec["full_done_spec"] = full_done_spec return super().transform_output_spec(output_spec) - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: - if not isinstance(input_spec, CompositeSpec): + def transform_input_spec(self, input_spec: Composite) -> Composite: + if not isinstance(input_spec, Composite): raise ValueError( - f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead." + f"input_spec was expected to be of type Composite. Got {type(input_spec)} instead." ) if input_spec["full_state_spec"] is None: - input_spec["full_state_spec"] = CompositeSpec( + input_spec["full_state_spec"] = Composite( shape=input_spec.shape, device=input_spec.device ) @@ -5962,9 +5956,7 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." ) - input_spec[ - unravel_key(("full_state_spec", step_count_key)) - ] = BoundedTensorSpec( + input_spec[unravel_key(("full_state_spec", step_count_key))] = Bounded( shape=shape, dtype=torch.int64, device=input_spec.device, @@ -6051,7 +6043,7 @@ def _reset( return tensordict_reset.exclude(*self.excluded_keys) return tensordict - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if not self.inverse: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] @@ -6171,7 +6163,7 @@ def _reset( *self.selected_keys, *reward_keys, *done_keys, *input_keys, strict=False ) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] full_observation_spec = output_spec["full_observation_spec"] @@ -6610,7 +6602,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec raise KeyError( f"Could not find root of init_key {init_key} within done_keys {self.parent.done_keys}." ) - observation_spec[init_key] = DiscreteTensorSpec( + observation_spec[init_key] = Categorical( 2, dtype=torch.bool, device=self.parent.device, @@ -6749,7 +6741,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: raise return tensordict - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: for done_key in self.parent.done_keys: if done_key in self.in_keys: for i, out_key in enumerate(self.out_keys): # noqa: B007 @@ -6791,7 +6783,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: del output_spec["full_observation_spec"][observation_key] return output_spec - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: for action_key in self.parent.action_keys: if action_key in self.in_keys: for i, out_key in enumerate(self.out_keys): # noqa: B007 @@ -7003,16 +6995,16 @@ class ActionMask(Transform): Examples: >>> import torch - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec + >>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite >>> from torchrl.envs.transforms import ActionMask, TransformedEnv >>> from torchrl.envs.common import EnvBase >>> class MaskedEnv(EnvBase): ... def __init__(self, *args, **kwargs): ... super().__init__(*args, **kwargs) - ... self.action_spec = DiscreteTensorSpec(4) - ... self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)) - ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) - ... self.reward_spec = UnboundedContinuousTensorSpec(1) + ... self.action_spec = Categorical(4) + ... self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool)) + ... self.observation_spec = Composite(obs=Unbounded(3)) + ... self.reward_spec = Unbounded(1) ... ... def _reset(self, tensordict=None): ... td = self.observation_spec.rand() @@ -7048,10 +7040,10 @@ class ActionMask(Transform): """ ACCEPTED_SPECS = ( - OneHotDiscreteTensorSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - MultiDiscreteTensorSpec, + OneHot, + Categorical, + MultiOneHot, + MultiCategorical, ) SPEC_TYPE_ERROR = "The action spec must be one of {}. Got {} instead." @@ -7477,7 +7469,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return BoundedTensorSpec( + return Bounded( shape=observation_spec.shape, device=observation_spec.device, dtype=observation_spec.dtype, @@ -7489,7 +7481,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: for key in self.in_keys: if key in self.parent.reward_keys: spec = self.parent.output_spec["full_reward_spec"][key] - self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec( + self.parent.output_spec["full_reward_spec"][key] = Bounded( shape=spec.shape, device=spec.device, dtype=spec.dtype, @@ -7512,31 +7504,31 @@ class RemoveEmptySpecs(Transform): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, \ - ... DiscreteTensorSpec + >>> from torchrl.data import Unbounded, Composite, \ + ... Categorical >>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs >>> >>> >>> class DummyEnv(EnvBase): ... def __init__(self, *args, **kwargs): ... super().__init__(*args, **kwargs) - ... self.observation_spec = CompositeSpec( - ... observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - ... other=CompositeSpec( - ... another_other=CompositeSpec(shape=self.batch_size), + ... self.observation_spec = Composite( + ... observation=UnboundedContinuous((*self.batch_size, 3)), + ... other=Composite( + ... another_other=Composite(shape=self.batch_size), ... shape=self.batch_size, ... ), ... shape=self.batch_size, ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - ... self.done_spec = DiscreteTensorSpec( + ... self.action_spec = UnboundedContinuous((*self.batch_size, 3)) + ... self.done_spec = Categorical( ... 2, (*self.batch_size, 1), dtype=torch.bool ... ) ... self.full_done_spec["truncated"] = self.full_done_spec[ ... "terminated"].clone() - ... self.reward_spec = CompositeSpec( - ... reward=UnboundedContinuousTensorSpec(*self.batch_size, 1), - ... other_reward=CompositeSpec(shape=self.batch_size), + ... self.reward_spec = Composite( + ... reward=UnboundedContinuous(*self.batch_size, 1), + ... other_reward=Composite(shape=self.batch_size), ... shape=self.batch_size ... ) ... @@ -7629,7 +7621,7 @@ def _sorter(key_val): return 0 return len(key) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] full_observation_spec = output_spec["full_observation_spec"] @@ -7637,19 +7629,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: for key, spec in sorted( full_done_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_done_spec[key] for key, spec in sorted( full_observation_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_observation_spec[key] for key, spec in sorted( full_reward_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_reward_spec[key] return output_spec @@ -7662,14 +7654,14 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: for key, spec in sorted( full_action_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): self._has_empty_input = True del full_action_spec[key] for key, spec in sorted( full_state_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): self._has_empty_input = True del full_state_spec[key] return input_spec @@ -7688,7 +7680,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: full_action_spec.items(True), key=self._sorter, reverse=True ): if ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and spec.is_empty() and key not in tensordict.keys(True) ): @@ -7698,7 +7690,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: full_state_spec.items(True), key=self._sorter, reverse=True ): if ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and spec.is_empty() and key not in tensordict.keys(True) ): @@ -7866,9 +7858,9 @@ class BatchSizeTransform(Transform): ... batch_locked = False ... def __init__(self): ... super().__init__() - ... self.observation_spec = CompositeSpec(observation=UnboundedContinuousTensorSpec(3)) - ... self.reward_spec = UnboundedContinuousTensorSpec(1) - ... self.action_spec = UnboundedContinuousTensorSpec(1) + ... self.observation_spec = Composite(observation=Unbounded(3)) + ... self.reward_spec = Unbounded(1) + ... self.action_spec = Unbounded(1) ... ... def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ... tensordict_batch_size = tensordict.batch_size if tensordict is not None else torch.Size([]) @@ -8016,12 +8008,12 @@ def transform_env_batch_size(self, batch_size: torch.Size): return self.batch_size return self.reshape_fn(torch.zeros(batch_size, device="meta")).shape - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.batch_size is not None: return output_spec.expand(self.batch_size) return self.reshape_fn(output_spec) - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: if self.batch_size is not None: return input_spec.expand(self.batch_size) return self.reshape_fn(input_spec) @@ -8480,7 +8472,7 @@ def _indent(s): def transform_input_spec(self, input_spec): try: action_spec = input_spec["full_action_spec", self.in_keys_inv[0]] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise TypeError( f"action spec type {type(action_spec)} is not supported." ) @@ -8539,9 +8531,9 @@ def custom_arange(nint): ] cls = ( - functools.partial(MultiDiscreteTensorSpec, remove_singleton=False) + functools.partial(MultiCategorical, remove_singleton=False) if self.categorical - else MultiOneHotDiscreteTensorSpec + else MultiOneHot ) if not isinstance(num_intervals, torch.Tensor): diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index d8bec1cf524..d394816372d 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -14,12 +14,7 @@ from torch import nn from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - DEVICE_TYPING, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import ( CenterCrop, Compose, @@ -198,8 +193,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("VC1Transform can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("VC1Transform can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -211,7 +206,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, self.embd_size]), device=device ) diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index e814f5da476..556eacf579c 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -9,11 +9,7 @@ from tensordict import set_lazy_legacy, TensorDict, TensorDictBase from torch.hub import load_state_dict_from_url -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( CatTensors, @@ -92,8 +88,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("_VIPNet can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("_VIPNet can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -105,7 +101,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, 1024]), device=device ) @@ -399,7 +395,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: if "full_state_spec" in input_spec.keys(): full_state_spec = input_spec["full_state_spec"] else: - full_state_spec = CompositeSpec( + full_state_spec = Composite( shape=input_spec.shape, device=input_spec.device ) # find the obs spec diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index de31ac99162..b723bd7b882 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -47,10 +47,10 @@ from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( - CompositeSpec, + Composite, NO_DEFAULT_RL as NO_DEFAULT, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.data.utils import check_no_exclusive_keys @@ -823,7 +823,7 @@ def check_env_specs( "you will need to first pass your stack through `torchrl.data.consolidate_spec`." ) if spec is None: - spec = CompositeSpec(shape=env.batch_size, device=env.device) + spec = Composite(shape=env.batch_size, device=env.device) td = last_td.select(*spec.keys(True, True), strict=True) if not spec.contains(td): raise AssertionError( @@ -835,7 +835,7 @@ def check_env_specs( ("obs", full_observation_spec), ): if spec is None: - spec = CompositeSpec(shape=env.batch_size, device=env.device) + spec = Composite(shape=env.batch_size, device=env.device) td = last_td.get("next").select(*spec.keys(True, True), strict=True) if not spec.contains(td): raise AssertionError( @@ -870,10 +870,10 @@ def _sort_keys(element): def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): - """Creates a CompositeSpec instance from a tensordict, assuming all values are unbounded. + """Creates a Composite instance from a tensordict, assuming all values are unbounded. Args: - data (tensordict.TensorDict): a tensordict to be mapped onto a CompositeSpec. + data (tensordict.TensorDict): a tensordict to be mapped onto a Composite. unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be unsqueezed to (1,). Defaults to ``True``. @@ -886,25 +886,25 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): ... }, []) >>> spec = make_composite_from_td(data) >>> print(spec) - CompositeSpec( - obs: UnboundedContinuousTensorSpec( + Composite( + obs: UnboundedContinuous( shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous), - action: UnboundedContinuousTensorSpec( + action: UnboundedContinuous( shape=torch.Size([2]), space=None, device=cpu, dtype=torch.int32, domain=continuous), - next: CompositeSpec( - obs: UnboundedContinuousTensorSpec( + next: Composite( + obs: UnboundedContinuous( shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous), - reward: UnboundedContinuousTensorSpec( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox(low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> assert (spec.zero() == data.zero_()).all() """ # custom funtion to convert a tensordict in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( + else Unbounded( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape @@ -1094,14 +1094,14 @@ def _terminated_or_truncated( contained a ``True``. Examples: - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec + >>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict - >>> spec = CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), - ... nested=CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + >>> spec = Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), + ... nested=Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ @@ -1147,7 +1147,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()): composite_spec = {} found_leaf = 0 for eot_key, item in full_done_spec.items(): - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): composite_spec[eot_key] = item else: found_leaf += 1 @@ -1219,14 +1219,14 @@ def terminated_or_truncated( contained a ``True``. Examples: - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec + >>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict - >>> spec = CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), - ... nested=CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + >>> spec = Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), + ... nested=Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ @@ -1274,7 +1274,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()): ) else: for eot_key, item in full_done_spec.items(): - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): any_eot = any_eot | inner_terminated_or_truncated( data=data.get(eot_key), full_done_spec=item, @@ -1562,8 +1562,8 @@ class RandomPolicy: Examples: >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> from torchrl.data.tensor_specs import Bounded + >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) >>> actor = RandomPolicy(action_spec=action_spec) >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] """ @@ -1574,7 +1574,7 @@ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): self.action_key = action_key def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, CompositeSpec): + if isinstance(self.action_spec, Composite): return td.update(self.action_spec.rand()) else: return td.set(self.action_key, self.action_spec.rand()) diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 6d9e6fb3b49..abc0e3d3f95 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -45,20 +45,20 @@ class CEMPlanner(MPCPlannerBase): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> from torchrl.envs.model_based import ModelBasedEnvBase >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) - ... self.state_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.state_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict( diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index 9c0bbc8f147..002094fb5d2 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -43,7 +43,7 @@ class MPPIPlanner(MPCPlannerBase): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> from torchrl.envs.model_based import ModelBasedEnvBase >>> from tensordict.nn import TensorDictModule >>> from torchrl.modules import ValueOperator @@ -51,14 +51,14 @@ class MPPIPlanner(MPCPlannerBase): >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) - ... self.state_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.state_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict( diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 81b7ec1e605..003c35cf0eb 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -22,7 +22,7 @@ from torch.distributions import Categorical from torchrl._utils import _replace_last -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _process_action_space_spec from torchrl.modules.tensordict_module.common import DistributionalDQNnet, SafeModule from torchrl.modules.tensordict_module.probabilistic import ( @@ -37,8 +37,8 @@ class Actor(SafeModule): The Actor class comes with default values for the out_keys (``["action"]``) and if the spec is provided but not as a - :class:`~torchrl.data.CompositeSpec` object, it will be - automatically translated into ``spec = CompositeSpec(action=spec)``. + :class:`~torchrl.data.Composite` object, it will be + automatically translated into ``spec = Composite(action=spec)``. Args: module (nn.Module): a :class:`~torch.nn.Module` used to map the input to @@ -70,11 +70,11 @@ class Actor(SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import Actor >>> torch.manual_seed(0) >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = UnboundedContinuousTensorSpec(4) + >>> action_spec = Unbounded(4) >>> module = torch.nn.Linear(4, 4) >>> td_module = Actor( ... module=module, @@ -111,9 +111,9 @@ def __init__( if ( "action" in out_keys and spec is not None - and not isinstance(spec, CompositeSpec) + and not isinstance(spec, Composite) ): - spec = CompositeSpec(action=spec) + spec = Composite(action=spec) super().__init__( module, @@ -128,8 +128,8 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): """General class for probabilistic actors in RL. The Actor class comes with default values for the out_keys (["action"]) - and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into :obj:`spec = CompositeSpec(action=spec)` + and if the spec is provided but not as a Composite object, it will be + automatically translated into :obj:`spec = Composite(action=spec)` Args: module (nn.Module): a :class:`torch.nn.Module` used to map the input to @@ -205,10 +205,10 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), + >>> action_spec = Bounded(shape=torch.Size([4]), ... low=-1, high=1) >>> module = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -382,12 +382,8 @@ def __init__( out_keys = list(distribution_map.keys()) else: out_keys = ["action"] - if ( - len(out_keys) == 1 - and spec is not None - and not isinstance(spec, CompositeSpec) - ): - spec = CompositeSpec({out_keys[0]: spec}) + if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite): + spec = Composite({out_keys[0]: spec}) super().__init__( module, @@ -424,7 +420,7 @@ class ValueOperator(TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import ValueOperator >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,]) >>> class CustomModule(nn.Module): @@ -577,22 +573,22 @@ def __init__( ) self.out_keys = out_keys action_key = out_keys[0] - if not isinstance(spec, CompositeSpec): - spec = CompositeSpec({action_key: spec}) + if not isinstance(spec, Composite): + spec = Composite({action_key: spec}) super().__init__() self.register_spec(safe=safe, spec=spec) register_spec = SafeModule.register_spec @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." ) self._spec = spec @@ -891,13 +887,13 @@ class QValueHook: >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) >>> hook = QValueHook("one_hot") >>> module.register_forward_hook(hook) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) >>> td = qvalue_actor(td) >>> print(td) @@ -975,7 +971,7 @@ class DistributionalQValueHook(QValueHook): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 @@ -989,7 +985,7 @@ class DistributionalQValueHook(QValueHook): ... >>> module = CustomDistributionalQval() >>> params = TensorDict.from_module(module) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) >>> module.register_forward_hook(hook) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) @@ -1085,12 +1081,12 @@ class QValueActor(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> # with a regular nn.Module >>> module = nn.Linear(4, 4) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) @@ -1106,7 +1102,7 @@ class QValueActor(SafeSequential): >>> # with a TensorDictModule >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5]) >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"]) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) @@ -1161,13 +1157,13 @@ def __init__( module, in_keys=in_keys, out_keys=[action_value_key] ) if spec is None: - spec = CompositeSpec() - if isinstance(spec, CompositeSpec): + spec = Composite() + if isinstance(spec, Composite): spec = spec.clone() if "action" not in spec.keys(): spec["action"] = None else: - spec = CompositeSpec(action=spec, shape=spec.shape[:-1]) + spec = Composite(action=spec, shape=spec.shape[:-1]) spec[action_value_key] = None spec["chosen_action_value"] = None qvalue = QValueModule( @@ -1237,7 +1233,7 @@ class DistributionalQValueActor(QValueActor): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 @@ -1247,7 +1243,7 @@ class DistributionalQValueActor(QValueActor): ... TensorDictModule(module, ["observation"], ["action_value"]), ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]), ... ) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor( ... module=module, ... spec=action_spec, @@ -1299,13 +1295,13 @@ def __init__( module, in_keys=in_keys, out_keys=[action_value_key] ) if spec is None: - spec = CompositeSpec() - if isinstance(spec, CompositeSpec): + spec = Composite() + if isinstance(spec, Composite): spec = spec.clone() if "action" not in spec.keys(): spec["action"] = None else: - spec = CompositeSpec(action=spec, shape=spec.shape[:-1]) + spec = Composite(action=spec, shape=spec.shape[:-1]) spec[action_value_key] = None qvalue = DistributionalQValueModule( @@ -1848,8 +1844,8 @@ def __init__( self.return_to_go_key = "return_to_go" self.inference_context = inference_context if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({self.action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({self.action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -1860,7 +1856,7 @@ def __init__( if self.action_key not in self._spec.keys(): self._spec[self.action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) self.checked = False @property @@ -1989,7 +1985,7 @@ class TanhModule(TensorDictModuleBase): Keyword Args: spec (TensorSpec, optional): if provided, the spec of the output. - If a CompositeSpec is provided, its key(s) must match the key(s) + If a Composite is provided, its key(s) must match the key(s) in out_keys. Otherwise, the key(s) of out_keys are assumed and the same spec is used for all outputs. low (float, np.ndarray or torch.Tensor): the lower bound of the space. @@ -2027,8 +2023,8 @@ class TanhModule(TensorDictModuleBase): >>> data['action'] tensor([-2.0000, 0.9991, 1.0000, -2.0000, -1.9991]) >>> # A spec can be provided - >>> from torchrl.data import BoundedTensorSpec - >>> spec = BoundedTensorSpec(low, high, shape=()) + >>> from torchrl.data import Bounded + >>> spec = Bounded(low, high, shape=()) >>> mod = TanhModule( ... in_keys=in_keys, ... low=low, @@ -2038,9 +2034,9 @@ class TanhModule(TensorDictModuleBase): ... ) >>> # One can also work with multiple keys >>> in_keys = ['a', 'b'] - >>> spec = CompositeSpec( - ... a=BoundedTensorSpec(-3, 0, shape=()), - ... b=BoundedTensorSpec(0, 3, shape=())) + >>> spec = Composite( + ... a=Bounded(-3, 0, shape=()), + ... b=Bounded(0, 3, shape=())) >>> mod = TanhModule( ... in_keys=in_keys, ... spec=spec, @@ -2077,13 +2073,13 @@ def __init__( ) self.out_keys = out_keys # action_spec can be a composite spec or not - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): for out_key in self.out_keys: if out_key not in spec.keys(True, True): spec[out_key] = None else: # if one spec is present, we assume it is the same for all keys - spec = CompositeSpec( + spec = Composite( {out_key: spec for out_key in out_keys}, ) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 11cc363b461..c9853c378e7 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -21,7 +21,7 @@ from torch import nn from torch.nn import functional as F -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import DEVICE_TYPING @@ -59,12 +59,12 @@ def _check_all_str(list_of_str, first_level=True): def _forward_hook_safe_action(module, tensordict_in, tensordict_out): try: spec = module.spec - if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): + if len(module.out_keys) > 1 and not isinstance(spec, Composite): raise RuntimeError( - "safe TensorDictModules with multiple out_keys require a CompositeSpec with matching keys. Got " + "safe TensorDictModules with multiple out_keys require a Composite with matching keys. Got " f"keys {module.out_keys}." ) - elif not isinstance(spec, CompositeSpec): + elif not isinstance(spec, Composite): out_key = module.out_keys[0] keys = [out_key] values = [spec] @@ -138,10 +138,10 @@ class SafeModule(TensorDictModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import TensorDictModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = UnboundedContinuousTensorSpec(8) + >>> spec = Unbounded(8) >>> module = torch.nn.GRUCell(4, 8) >>> td_fmodule = TensorDictModule( ... module=module, @@ -216,18 +216,18 @@ def register_spec(self, safe, spec): spec = spec.clone() if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") - elif spec is not None and not isinstance(spec, CompositeSpec): + elif spec is not None and not isinstance(spec, Composite): if len(self.out_keys) > 1: raise RuntimeError( f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. " - "Consider using a CompositeSpec object or no spec at all." + "Consider using a Composite object or no spec at all." ) - spec = CompositeSpec({self.out_keys[0]: spec}) - elif spec is not None and isinstance(spec, CompositeSpec): + spec = Composite({self.out_keys[0]: spec}) + elif spec is not None and isinstance(spec, Composite): if "_" in spec.keys() and spec["_"] is not None: warnings.warn('got a spec with key "_": it will be ignored') elif spec is None: - spec = CompositeSpec() + spec = Composite() # unravel_key_list(self.out_keys) can be removed once 473 is merged in tensordict spec_keys = set(unravel_key_list(list(spec.keys(True, True)))) @@ -247,7 +247,7 @@ def register_spec(self, safe, spec): self.safe = safe if safe: if spec is None or ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( @@ -257,14 +257,14 @@ def register_spec(self, safe, spec): self.register_forward_hook(_forward_hook_safe_action) @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." ) self._spec = spec diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 5a41f11bf76..3b19b60048a 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -16,7 +16,7 @@ ) from tensordict.utils import expand_as_right, expand_right, NestedKey -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.tensordict_module.common import _forward_hook_safe_action @@ -64,9 +64,9 @@ class EGreedyModule(TensorDictModuleBase): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictSequential >>> from torchrl.modules import EGreedyModule, Actor - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(spec=spec, module=module) >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2)) @@ -115,8 +115,8 @@ def __init__( self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec @property @@ -155,7 +155,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: cond = expand_as_right(cond, out) spec = self.spec if spec is not None: - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): spec = spec[self.action_key] if spec.shape != out.shape: # In batched envs if the spec is passed unbatched, the rand() will not @@ -214,9 +214,9 @@ class EGreedyWrapper(TensorDictModuleWrapper): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import EGreedyWrapper, Actor - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(spec=spec, module=module) >>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2) @@ -267,7 +267,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): mean (float, optional): mean of each output element’s normal distribution. std (float, optional): standard deviation of each output element’s normal distribution. action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to + its output spec will be of type Composite. One needs to know where to find the action spec. Default is "action". spec (TensorSpec, optional): if provided, the sampled action will be @@ -323,8 +323,8 @@ def __init__( f"The action key {action_key} was not found in the td_module out_keys {self.td_module.out_keys}." ) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -335,7 +335,7 @@ def __init__( if action_key not in self._spec.keys(True, True): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) self.safe = safe if self.safe: @@ -410,7 +410,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): Keyword Args: action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to + its output spec will be of type Composite. One needs to know where to find the action spec. default: "action" @@ -453,8 +453,8 @@ def __init__( self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) else: raise RuntimeError("spec cannot be None.") self._spec = spec @@ -570,10 +570,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(module=module, spec=spec) >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) @@ -647,8 +647,8 @@ def __init__( steps_key = self.ou.steps_key if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -659,7 +659,7 @@ def __init__( if action_key not in self._spec.keys(True, True): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) ou_specs = { noise_key: None, steps_key: None, @@ -783,10 +783,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictSequential - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(module=module, spec=spec) >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec) @@ -851,8 +851,8 @@ def __init__( steps_key = self.ou.steps_key if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec else: raise RuntimeError("spec cannot be None.") diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 725323e1a28..4b38b19c699 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -15,7 +15,7 @@ TensorDictModule, ) from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.modules.distributions import Delta from torchrl.modules.tensordict_module.common import _forward_hook_safe_action from torchrl.modules.tensordict_module.sequence import SafeSequential @@ -129,18 +129,18 @@ def __init__( spec = spec.clone() if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") - elif spec is not None and not isinstance(spec, CompositeSpec): + elif spec is not None and not isinstance(spec, Composite): if len(self.out_keys) > 1: raise RuntimeError( f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. " - "Consider using a CompositeSpec object or no spec at all." + "Consider using a Composite object or no spec at all." ) - spec = CompositeSpec({self.out_keys[0]: spec}) - elif spec is not None and isinstance(spec, CompositeSpec): + spec = Composite({self.out_keys[0]: spec}) + elif spec is not None and isinstance(spec, Composite): if "_" in spec.keys(): warnings.warn('got a spec with key "_": it will be ignored') elif spec is None: - spec = CompositeSpec() + spec = Composite() spec_keys = set(unravel_key_list(list(spec.keys(True, True)))) out_keys = set(unravel_key_list(self.out_keys)) if spec_keys != out_keys: @@ -159,7 +159,7 @@ def __init__( self.safe = safe if safe: if spec is None or ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( @@ -169,14 +169,14 @@ def __init__( self.register_forward_hook(_forward_hook_safe_action) @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." ) self._spec = spec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 048ddedbf9d..657bf6649d7 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -16,7 +16,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase -from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( _inv_pad_sequence, _split_and_pad_sequence, @@ -581,12 +581,8 @@ def make_tuple(key): ) return TensorDictPrimer( { - in_key1: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), - in_key2: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), + in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), + in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), } ) @@ -1329,9 +1325,7 @@ def make_tuple(key): ) return TensorDictPrimer( { - in_key1: UnboundedContinuousTensorSpec( - shape=(self.gru.num_layers, self.gru.hidden_size) - ), + in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)), } ) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 41ddb55fb35..938843e624f 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -8,7 +8,7 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.modules.tensordict_module.common import SafeModule @@ -33,11 +33,11 @@ class SafeSequential(TensorDictSequential, SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) - >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None) + >>> spec1 = Composite(hidden=Unbounded(4), loc=None, scale=None) >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( @@ -48,7 +48,7 @@ class SafeSequential(TensorDictSequential, SafeModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... ) - >>> spec2 = UnboundedContinuousTensorSpec(8) + >>> spec2 = Unbounded(8) >>> module2 = torch.nn.Linear(4, 8) >>> td_module2 = TensorDictModule( ... module=module2, @@ -74,12 +74,12 @@ class SafeSequential(TensorDictSequential, SafeModule): is_shared=False) >>> # The module spec aggregates all the input specs: >>> print(td_module.spec) - CompositeSpec( - hidden: UnboundedContinuousTensorSpec( + Composite( + hidden: UnboundedContinuous( shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous), loc: None, scale: None, - output: UnboundedContinuousTensorSpec( + output: UnboundedContinuous( shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous)) In the vmap case: @@ -112,12 +112,12 @@ def __init__( in_keys, out_keys = self._compute_in_and_out_keys(modules) - spec = CompositeSpec() + spec = Composite() for module in modules: try: spec.update(module.spec) except AttributeError: - spec.update(CompositeSpec({key: None for key in module.out_keys})) + spec.update(Composite({key: None for key in module.out_keys})) super(TensorDictSequential, self).__init__( spec=spec, diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 0f3088a8943..9a8914aab89 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -46,8 +46,8 @@ def get_primers_from_module(module): >>> primers = get_primers_from_module(model) >>> print(primers) - TensorDictPrimer(primers=CompositeSpec( - recurrent_state: UnboundedContinuousTensorSpec( + TensorDictPrimer(primers=Composite( + recurrent_state: UnboundedContinuous( shape=torch.Size([1, 10]), space=None, device=cpu, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index a236b80d56c..d3b2b4d2ac2 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -96,14 +96,14 @@ class A2CLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -147,14 +147,14 @@ class A2CLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f1e2aa9c532..6a6cf8548e4 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -19,7 +19,7 @@ from tensordict.utils import NestedKey, unravel_key from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -100,14 +100,14 @@ class CQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -160,14 +160,14 @@ class CQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -405,8 +405,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -933,11 +933,11 @@ class DiscreteCQLLoss(LossModule): Examples: >>> from torchrl.modules import MLP, QValueActor - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.objectives import DiscreteCQLLoss >>> n_obs, n_act = 4, 3 >>> value_net = MLP(in_features=n_obs, out_features=n_act) - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec) >>> loss = DiscreteCQLLoss(actor, action_space=spec) >>> batch = [10,] @@ -969,12 +969,12 @@ class DiscreteCQLLoss(LossModule): Examples: >>> from torchrl.objectives import DiscreteCQLLoss - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torch import nn >>> import torch >>> n_obs = 3 >>> n_action = 4 - >>> action_spec = OneHotDiscreteTensorSpec(n_action) + >>> action_spec = OneHot(n_action) >>> value_network = nn.Linear(n_obs, n_action) # a simple value model >>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec) >>> # define data diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index e76e3438c09..d86442fca12 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -15,7 +15,7 @@ from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule @@ -98,14 +98,14 @@ class CrossQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.crossq import CrossQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -156,14 +156,14 @@ class CrossQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives import CrossQLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -375,8 +375,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 6e1cf0f5eb3..7dc6b23212a 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -50,12 +50,12 @@ class DDPGLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): @@ -100,12 +100,12 @@ class DDPGLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index c1ed8b2cffe..4f805c1b411 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -17,7 +17,7 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule @@ -251,8 +251,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) target_entropy = -float( np.prod(action_spec[self.tensor_keys.action].shape) ) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 7b35598c474..1f3ec714f53 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -52,9 +52,9 @@ class DQNLoss(LossModule): https://arxiv.org/abs/1509.06461. Defaults to ``False``. action_space (str or TensorSpec, optional): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). If not provided, an attempt to retrieve it from the value network will be made. priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] @@ -68,10 +68,10 @@ class DQNLoss(LossModule): Examples: >>> from torchrl.modules import MLP - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> n_obs, n_act = 4, 3 >>> value_net = MLP(in_features=n_obs, out_features=n_act) - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec) >>> loss = DQNLoss(actor, action_space=spec) >>> batch = [10,] @@ -99,12 +99,12 @@ class DQNLoss(LossModule): Examples: >>> from torchrl.objectives import DQNLoss - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torch import nn >>> import torch >>> n_obs = 3 >>> n_action = 4 - >>> action_spec = OneHotDiscreteTensorSpec(n_action) + >>> action_spec = OneHot(n_action) >>> value_network = nn.Linear(n_obs, n_action) # a simple value model >>> dqn_loss = DQNLoss(value_network, action_space=action_spec) >>> # define data diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 74cfe504e78..04d7e020551 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -73,14 +73,14 @@ class IQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -136,14 +136,14 @@ class IQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -541,9 +541,9 @@ class DiscreteIQLLoss(IQLLoss): Keyword Args: action_space (str or TensorSpec): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -569,14 +569,14 @@ class DiscreteIQLLoss(IQLLoss): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions.discrete import OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -627,14 +627,14 @@ class DiscreteIQLLoss(IQLLoss): >>> import torch >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions.discrete import OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index c9dc281ef41..ce4cc8ddbb8 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -63,9 +63,9 @@ class QMixerLoss(LossModule): create a double DQN. Default is ``False``. action_space (str or TensorSpec, optional): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). If not provided, an attempt to retrieve it from the value network will be made. priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index c29bc73dfa8..d79f0b2ea84 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -151,14 +151,14 @@ class PPOLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> from torchrl.data.tensor_specs import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -204,13 +204,13 @@ class PPOLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> from torchrl.data.tensor_specs import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 1522fd7749e..cda2c62894e 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -16,7 +16,7 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule @@ -93,14 +93,14 @@ class REDQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -155,13 +155,13 @@ class REDQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -367,8 +367,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index af9f7d99b46..08ff896610c 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -100,7 +100,7 @@ class ReinforceLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + >>> from torchrl.data.tensor_specs import Unbounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -115,7 +115,7 @@ class ReinforceLoss(LossModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... in_keys=["loc", "scale"], - ... spec=UnboundedContinuousTensorSpec(n_act),) + ... spec=Unbounded(n_act),) >>> loss = ReinforceLoss(actor_net, value_net) >>> batch = 2 >>> data = TensorDict({ @@ -146,7 +146,7 @@ class ReinforceLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + >>> from torchrl.data.tensor_specs import Unbounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -160,7 +160,7 @@ class ReinforceLoss(LossModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... in_keys=["loc", "scale"], - ... spec=UnboundedContinuousTensorSpec(n_act),) + ... spec=Unbounded(n_act),) >>> loss = ReinforceLoss(actor_net, value_net) >>> batch = 2 >>> loss_actor, loss_value = loss( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index df444eac053..6e57a927f37 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -18,7 +18,7 @@ from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor @@ -117,14 +117,14 @@ class SACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -180,14 +180,14 @@ class SACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -440,8 +440,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -818,9 +818,9 @@ class DiscreteSACLoss(LossModule): qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed. action_space (str or TensorSpec): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). num_actions (int, optional): number of actions in the action space. To be provided if target_entropy is set to "auto". num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 2. @@ -852,7 +852,7 @@ class DiscreteSACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -860,7 +860,7 @@ class DiscreteSACLoss(LossModule): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -909,13 +909,13 @@ class DiscreteSACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index eb1027ad936..922d6df7a74 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule @@ -83,14 +83,14 @@ class TD3Loss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3 import TD3Loss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -139,11 +139,11 @@ class TD3Loss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.td3 import TD3Loss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -283,7 +283,7 @@ def __init__( f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." ) elif action_spec is not None: - if isinstance(action_spec, CompositeSpec): + if isinstance(action_spec, Composite): if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -296,9 +296,9 @@ def __init__( action_spec = action_spec[self.tensor_keys.action][ (0,) * len(action_container_shape) ] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise ValueError( - f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." + f"action_spec is not of type Bounded but {type(action_spec)}." ) low = action_spec.space.low high = action_spec.space.high diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index aa87ea9aa1a..cd40ac1e029 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule @@ -94,14 +94,14 @@ class TD3BCLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -152,11 +152,11 @@ class TD3BCLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -299,7 +299,7 @@ def __init__( f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." ) elif action_spec is not None: - if isinstance(action_spec, CompositeSpec): + if isinstance(action_spec, Composite): if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -312,9 +312,9 @@ def __init__( action_spec = action_spec[self.tensor_keys.action][ (0,) * len(action_container_shape) ] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise ValueError( - f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." + f"action_spec is not of type Bounded but {type(action_spec)}." ) low = action_spec.space.low high = action_spec.space.high diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index b7fb8ab4ed2..73e3b5bdaab 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -18,7 +18,7 @@ from torchrl._utils import _can_be_pickled from torchrl.data import TensorSpec -from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import NonTensor, Unbounded from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import EnvBase from torchrl.envs.transforms import ObservationTransform, Transform @@ -506,11 +506,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec self._call(td_in) obs = td_in.get(self.out_keys[0]) if isinstance(obs, NonTensorData): - spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape) + spec = NonTensor(device=obs.device, dtype=obs.dtype, shape=obs.shape) else: - spec = UnboundedContinuousTensorSpec( - device=obs.device, dtype=obs.dtype, shape=obs.shape - ) + spec = Unbounded(device=obs.device, dtype=obs.dtype, shape=obs.shape) observation_spec[self.out_keys[0]] = spec if switch: self.switch() diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 0c9ec92cff4..4bae738101d 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -9,11 +9,7 @@ from tensordict import set_lazy_legacy from tensordict.nn import InteractionType from torch import nn -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based.dreamer import DreamerEnv @@ -153,7 +149,7 @@ def make_dqn_actor( actor_class = QValueActor actor_kwargs = {} - if isinstance(action_spec, DiscreteTensorSpec): + if isinstance(action_spec, Categorical): # if action spec is modeled as categorical variable, we still need to have features equal # to the number of possible choices and also set categorical behavioural for actors. actor_kwargs.update({"action_space": "categorical"}) @@ -182,7 +178,7 @@ def make_dqn_actor( model = actor_class( module=net, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=[in_key], safe=True, **actor_kwargs, @@ -385,13 +381,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), @@ -404,7 +400,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.action_spec}), ), ) return actor_simulator @@ -436,12 +432,12 @@ def _dreamer_make_actor_real( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, ), } @@ -453,9 +449,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec( - **{action_key: proof_environment.action_spec.to("cpu")} - ), + spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), ), ), SafeModule( @@ -536,8 +530,8 @@ def _dreamer_make_mbenv( model_based_env.set_specs_from_env(proof_environment) model_based_env = TransformedEnv(model_based_env) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), # "action": proof_environment.action_spec, } model_based_env.append_transform( diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 1bf7fd57e83..869f0f980b3 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -683,7 +683,7 @@ def get_env_stats(): ) -from torchrl.data import CompositeSpec +from torchrl.data import Composite ############################################################################### # Building the model @@ -756,7 +756,7 @@ def make_ddpg_actor( actor, distribution_class=TanhDelta, in_keys=["param"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), ).to(device) q_net = DdpgMlpQNet() diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index d25bc2cdd8a..19f79c37480 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -107,7 +107,7 @@ from tensordict.nn import TensorDictModule from torch import nn -from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data import Bounded, Composite, Unbounded from torchrl.envs import ( CatTensors, EnvBase, @@ -410,14 +410,14 @@ def _reset(self, tensordict): def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] - self.observation_spec = CompositeSpec( - th=BoundedTensorSpec( + self.observation_spec = Composite( + th=Bounded( low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32, ), - thdot=BoundedTensorSpec( + thdot=Bounded( low=-td_params["params", "max_speed"], high=td_params["params", "max_speed"], shape=(), @@ -433,25 +433,23 @@ def _make_spec(self, td_params): self.state_spec = self.observation_spec.clone() # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` will be called supported - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-td_params["params", "max_torque"], high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + self.reward_spec = Unbounded(shape=(*td_params.shape, 1)) def make_composite_from_td(td): # custom function to convert a ``tensordict`` in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) for key, tensor in td.items() }, shape=td.shape, @@ -694,7 +692,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, @@ -718,7 +716,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 29192d1c10e..6cec838fdc2 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -572,10 +572,10 @@ def exec_sequence(params, data): # ------------------------------ torch.manual_seed(0) -from torchrl.data import BoundedTensorSpec +from torchrl.data import Bounded from torchrl.modules import SafeModule -spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) +spec = Bounded(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True