[Feature] Marking the time dimension (pytorch#1095)
Co-authored-by: Rohit Nigam <rohitnigam@meta.com>
Co-authored-by: Rohit Nigam <rohitnigam@gmail.com>
3 people authored May 5, 2023
1 parent 39fe662 commit 99a95e3
Showing 5 changed files with 120 additions and 4 deletions.
6 changes: 6 additions & 0 deletions test/test_collector.py
@@ -212,6 +212,7 @@ def env_fn(seed):
device, policy_device, storing_device, d.device.type
)
break
assert d.names[-1] == "time"

collector.shutdown()

@@ -231,6 +232,7 @@ def env_fn(seed):
device, policy_device, storing_device, d.device.type
)
break
assert d.names[-1] == "time"

ccollector.shutdown()

@@ -273,6 +275,7 @@ def env_fn(seed):
b2 = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(b1, b2)
collector.shutdown()
@@ -292,6 +295,7 @@ def env_fn(seed):
b2c = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(b1c, b2c)

@@ -508,6 +512,7 @@ def env_fn():
assert b.numel() == -(-frames_per_batch // num_env) * num_env
if i == 5:
break
assert b.names[-1] == "time"
ccollector.shutdown()

ccollector = MultiSyncDataCollector(
@@ -525,6 +530,7 @@ def env_fn():
)
if i == 5:
break
assert b.names[-1] == "time"
ccollector.shutdown()


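The assertions added in this file all check the same new invariant: every batch a collector yields names its trailing dimension ``"time"``. A minimal sketch of that check outside the test suite, assuming gym's ``Pendulum-v1`` is available (any TorchRL environment would do):

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> collector = SyncDataCollector(
...     GymEnv("Pendulum-v1"),
...     policy=None,  # falls back to a random policy
...     frames_per_batch=50,
...     total_frames=200,
... )
>>> for batch in collector:
...     # the last entry of the dimension names is the marked time dimension
...     assert batch.names[-1] == "time"
>>> collector.shutdown()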
1 change: 1 addition & 0 deletions test/test_distributed.py
@@ -100,6 +100,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch):
for data in collector:
total += data.numel()
assert data.numel() == frames_per_batch
assert data.names[-1] == "time"
collector.shutdown()
assert total == 1000
queue.put("passed")
3 changes: 3 additions & 0 deletions test/test_env.py
@@ -156,12 +156,14 @@ def test_rollout(env_name, frame_skip, seed=0):
env.set_seed(seed)
env.reset()
rollout1 = env.rollout(max_steps=100)
assert rollout1.names[-1] == "time"

torch.manual_seed(seed)
np.random.seed(seed)
env.set_seed(seed)
env.reset()
rollout2 = env.rollout(max_steps=100)
assert rollout2.names[-1] == "time"

assert_allclose_td(rollout1, rollout2)

@@ -231,6 +233,7 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0):
env = SerialEnv(3, envs)
env.set_seed(100)
out = env.rollout(100, break_when_any_done=False)
assert out.names[-1] == "time"
assert out.shape == torch.Size([3, 100])
assert (
out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
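For batched environments the marked dimension is still the last one, with the leading batch dimensions left unnamed, which is what the ``SerialEnv`` test above relies on. A short sketch of that case, again assuming the ``Pendulum-v1`` backend:

>>> import torch
>>> from torchrl.envs import SerialEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))
>>> rollout = env.rollout(10, break_when_any_done=False)
>>> assert rollout.batch_size == torch.Size([3, 10])
>>> assert rollout.names == [None, "time"]  # only the last dim is named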
7 changes: 7 additions & 0 deletions torchrl/collectors/collectors.py
@@ -464,6 +464,12 @@ class SyncDataCollector(DataCollectorBase):
is_shared=False)
>>> del collector
The collector delivers batches of data that are marked with a ``"time"``
dimension.
Examples:
>>> assert data.names[-1] == "time"
"""

def __init__(
@@ -665,6 +671,7 @@ def __init__(
device=self.storing_device,
),
)
self._tensordict_out.refine_names(..., "time")

if split_trajs is None:
split_trajs = False
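The single functional change in this file is the ``refine_names(..., "time")`` call on the collector's preallocated output buffer. ``refine_names`` follows named-tensor semantics: the ellipsis matches any number of unnamed leading dimensions, the last dimension is pinned to ``"time"``, and re-applying matching names is a no-op. A sketch of those semantics on a bare tensordict (the shapes are hypothetical):

>>> from tensordict import TensorDict
>>> td = TensorDict({}, batch_size=[4, 25])  # e.g. 4 envs, 25 steps
>>> td = td.refine_names(..., "time")
>>> td.names
[None, 'time']
>>> td = td.refine_names(..., "time")  # no-op: the names already match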
107 changes: 103 additions & 4 deletions torchrl/envs/common.py
@@ -15,16 +15,16 @@
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl._utils import prod, seed_generator

from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
)

from .._utils import prod, seed_generator
from ..data.utils import DEVICE_TYPING
from .utils import get_available_libraries, step_mdp
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import get_available_libraries, step_mdp

LIBRARIES = get_available_libraries()

@@ -219,6 +219,13 @@ def batch_size(self) -> TensorSpec:
def batch_size(self, value: torch.Size) -> None:
self._batch_size = torch.Size(value)

def ndimension(self):
return len(self.batch_size)

@property
def ndim(self):
return self.ndimension()

# Parent specs: input and output spec.
@property
def input_spec(self) -> TensorSpec:
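The new ``ndimension``/``ndim`` accessors mirror ``torch.Tensor.ndim`` and simply report ``len(env.batch_size)``; as the rollout docstring below notes, this is the index at which the ``"time"`` name lands. A hedged sketch, assuming the same ``Pendulum-v1`` backend:

>>> from torchrl.envs import SerialEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> GymEnv("Pendulum-v1").ndim  # a single env has no batch dimensions
0
>>> SerialEnv(3, lambda: GymEnv("Pendulum-v1")).ndim
1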
@@ -661,6 +668,97 @@ def rollout(
Returns:
TensorDict object containing the resulting trajectory.
The data returned will be marked with a "time" dimension name for the last
dimension of the tensordict (at the ``env.ndim`` index).
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import TransformedEnv, StepCounter
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
>>> rollout = env.rollout(max_steps=1000)
>>> print(rollout)
TensorDict(
fields={
action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([20]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([20]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
['time']
>>> # with envs that contain more dimensions
>>> from torchrl.envs import SerialEnv
>>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
>>> rollout = env.rollout(max_steps=1000)
>>> print(rollout)
TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 20]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 20]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
[None, 'time']
In some instances, a contiguous tensordict cannot be obtained because
the per-step tensordicts cannot be stacked. This can happen when the
data returned at each step has a different shape, or when different
environments are executed together. In that case, ``return_contiguous=False``
will cause the returned tensordict to be a lazy stack of tensordicts:
Examples:
>>> rollout = env.rollout(4, return_contiguous=False)
>>> print(rollout)
LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: LazyStackedTensorDict(
fields={
done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
[None, 'time']
"""
try:
policy_device = next(policy.parameters()).device
@@ -718,6 +816,7 @@ def policy(td):
batch_size = self.batch_size if tensordict is None else tensordict.batch_size

out_td = torch.stack(tensordicts, len(batch_size))
out_td.refine_names(..., "time")
if return_contiguous:
return out_td.contiguous()
return out_td
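This last hunk is where the name is actually attached at rollout time: the per-step tensordicts are stacked along a new dimension at index ``len(batch_size)``, and ``refine_names(..., "time")`` marks that trailing dimension before the (possibly lazy) stack is returned. A sketch of that return path on stand-in data (the ``tensordicts`` list here is hypothetical):

>>> import torch
>>> from tensordict import TensorDict
>>> tensordicts = [TensorDict({}, batch_size=[3]) for _ in range(4)]  # 4 steps
>>> out_td = torch.stack(tensordicts, 1)  # stack dim == len(batch_size)
>>> out_td = out_td.refine_names(..., "time")
>>> out_td.names
[None, 'time']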
