diff --git a/test/test_collector.py b/test/test_collector.py
index 87393165450..a5d147ebcad 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -212,6 +212,7 @@ def env_fn(seed):
             device, policy_device, storing_device, d.device.type
         )
         break
+    assert d.names[-1] == "time"
     collector.shutdown()

@@ -231,6 +232,7 @@ def env_fn(seed):
             device, policy_device, storing_device, d.device.type
         )
         break
+    assert d.names[-1] == "time"
     ccollector.shutdown()

@@ -273,6 +275,7 @@ def env_fn(seed):
             b2 = d
         else:
             break
+    assert d.names[-1] == "time"
     with pytest.raises(AssertionError):
         assert_allclose_td(b1, b2)
     collector.shutdown()
@@ -292,6 +295,7 @@ def env_fn(seed):
             b2c = d
         else:
             break
+    assert d.names[-1] == "time"
     with pytest.raises(AssertionError):
         assert_allclose_td(b1c, b2c)

@@ -508,6 +512,7 @@ def env_fn():
         assert b.numel() == -(-frames_per_batch // num_env) * num_env
         if i == 5:
             break
+    assert b.names[-1] == "time"
     ccollector.shutdown()

     ccollector = MultiSyncDataCollector(
@@ -525,6 +530,7 @@ def env_fn():
         )
         if i == 5:
             break
+    assert b.names[-1] == "time"
     ccollector.shutdown()

diff --git a/test/test_distributed.py b/test/test_distributed.py
index fc826d2911f..9b18c709436 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -100,6 +100,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch):
         for data in collector:
             total += data.numel()
             assert data.numel() == frames_per_batch
+            assert data.names[-1] == "time"
         collector.shutdown()
         assert total == 1000
         queue.put("passed")
diff --git a/test/test_env.py b/test/test_env.py
index 2f84037f7b1..a44237f9da4 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -156,12 +156,14 @@ def test_rollout(env_name, frame_skip, seed=0):
     env.set_seed(seed)
     env.reset()
     rollout1 = env.rollout(max_steps=100)
+    assert rollout1.names[-1] == "time"

     torch.manual_seed(seed)
     np.random.seed(seed)
     env.set_seed(seed)
     env.reset()
     rollout2 = env.rollout(max_steps=100)
+    assert rollout2.names[-1] == "time"

     assert_allclose_td(rollout1, rollout2)

@@ -231,6 +233,7 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0):
         env = SerialEnv(3, envs)
     env.set_seed(100)
     out = env.rollout(100, break_when_any_done=False)
+    assert out.names[-1] == "time"
     assert out.shape == torch.Size([3, 100])
     assert (
         out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 3d37f22ed06..2c15ab2db84 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -464,6 +464,12 @@ class SyncDataCollector(DataCollectorBase):
             is_shared=False)
         >>> del collector

+    The collector delivers batches of data that are marked with a ``"time"``
+    dimension.
+
+    Examples:
+        >>> assert data.names[-1] == "time"
+
     """

     def __init__(
@@ -665,6 +671,7 @@ def __init__(
                 device=self.storing_device,
             ),
         )
+        self._tensordict_out.refine_names(..., "time")

         if split_trajs is None:
             split_trajs = False
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index b0442bdbfa4..23b98871a18 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -15,16 +15,16 @@
 import torch.nn as nn

 from tensordict.tensordict import TensorDict, TensorDictBase

+from torchrl._utils import prod, seed_generator
+
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     DiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
 )
-
-from .._utils import prod, seed_generator
-from ..data.utils import DEVICE_TYPING
-from .utils import get_available_libraries, step_mdp
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.envs.utils import get_available_libraries, step_mdp

 LIBRARIES = get_available_libraries()
@@ -219,6 +219,13 @@ def batch_size(self) -> TensorSpec:
     def batch_size(self, value: torch.Size) -> None:
         self._batch_size = torch.Size(value)

+    def ndimension(self):
+        return len(self.batch_size)
+
+    @property
+    def ndim(self):
+        return self.ndimension()
+
     # Parent specs: input and output spec.
     @property
     def input_spec(self) -> TensorSpec:
@@ -661,6 +668,97 @@ def rollout(
         Returns:
             TensorDict object containing the resulting trajectory.

+        The data returned will be marked with a "time" dimension name for the last
+        dimension of the tensordict (at the ``env.ndim`` index).
+
+        Examples:
+            >>> from torchrl.envs.libs.gym import GymEnv
+            >>> from torchrl.envs.transforms import TransformedEnv, StepCounter
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            ['time']
+            >>> # with envs that contain more dimensions
+            >>> from torchrl.envs import SerialEnv
+            >>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
+        In some instances, a contiguous tensordict cannot be obtained because
+        the per-step tensordicts cannot be stacked contiguously. This can happen
+        when the data returned at each step has a different shape, or when
+        different environments are executed together. In that case,
+        ``return_contiguous=False`` will return a lazy stack of tensordicts:
+
+        Examples:
+            >>> rollout = env.rollout(4, return_contiguous=False)
+            >>> print(rollout)
+            LazyStackedTensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: LazyStackedTensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 4]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
         """
         try:
             policy_device = next(policy.parameters()).device
@@ -718,6 +816,7 @@ def policy(td):
         batch_size = self.batch_size if tensordict is None else tensordict.batch_size

         out_td = torch.stack(tensordicts, len(batch_size))
+        out_td.refine_names(..., "time")
         if return_contiguous:
             return out_td.contiguous()
         return out_td
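
Usage sketch (illustrative only, not part of the patch above): it exercises the behaviour the diff introduces, namely that rollouts and collector batches name their trailing dimension "time" at index ``env.ndim``. It assumes torchrl with a working Gym backend and the Pendulum-v1 environment.

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import SerialEnv
    from torchrl.envs.libs.gym import GymEnv

    # Rollouts mark their trailing dimension as "time"; batch dims stay unnamed.
    env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
    rollout = env.rollout(max_steps=10)
    assert rollout.names[-1] == "time"
    assert rollout.names.index("time") == env.ndim  # here env.ndim == 1

    # Collector batches carry the same trailing "time" dimension name.
    # policy=None falls back to a random policy.
    collector = SyncDataCollector(
        env, policy=None, frames_per_batch=20, total_frames=40
    )
    for data in collector:
        assert data.names[-1] == "time"
    collector.shutdown()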