Skip to content

Commit

Permalink
[Feature] TicTacToeEnv
Browse files Browse the repository at this point in the history
ghstack-source-id: a48a28b322af074877e9a66261310ba41a9599f0
Pull Request resolved: #2301
  • Loading branch information
vmoens committed Jul 22, 2024
1 parent 87f66e8 commit e28440f
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 12 deletions.
12 changes: 12 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,18 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w
ParallelEnv
EnvCreator


Custom native TorchRL environments
----------------------------------

TorchRL offers a series of custom built-in environments.

.. autosummary::
:toctree: generated/
:template: rl_template.rst

TicTacToeEnv

Multi-agent environments
------------------------

Expand Down
13 changes: 13 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
EnvCreator,
ParallelEnv,
SerialEnv,
TicTacToeEnv,
)
from torchrl.envs.batched_envs import _stackable
from torchrl.envs.gym_like import default_info_dict_reader
Expand Down Expand Up @@ -3307,6 +3308,18 @@ def test_partial_rest(self, batched):
assert s["next", "string"] == ["6", "6"]


class TestCustomEnvs:
def test_tictactoe(self):
torch.manual_seed(0)
env = TicTacToeEnv()
check_env_specs(env)
for _ in range(10):
r = env.rollout(10)
assert r.shape[-1] < 10
r = env.rollout(10, tensordict=TensorDict(batch_size=[5]))
assert r.shape[-1] < 10


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
12 changes: 9 additions & 3 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,7 +3013,9 @@ def test_repr(self):
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([3])),
domain=continuous),
device=cpu,
shape=torch.Size([3])),
1 ->
lidar: BoundedTensorSpec(
shape=torch.Size([20]),
Expand All @@ -3031,15 +3033,19 @@ def test_repr(self):
high=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([3])),
domain=continuous),
device=cpu,
shape=torch.Size([3])),
2 ->
individual_2_obs: CompositeSpec(
individual_1_obs_0: UnboundedContinuousTensorSpec(
shape=torch.Size([3, 1, 2, 3]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([3]))}},
domain=continuous),
device=cpu,
shape=torch.Size([3]))}},
device=cpu,
shape={torch.Size((3,))},
stack_dim={c.stack_dim})"""
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4100,7 +4100,7 @@ def __repr__(self) -> str:
indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
]
sub_str = ",\n".join(sub_str)
return f"CompositeSpec(\n{sub_str}, device={self._device}, shape={self.shape})"
return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})"

def type_check(
self,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
from .custom import TicTacToeEnv
from .env_creator import EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
Expand Down
15 changes: 7 additions & 8 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,10 +2355,13 @@ def rollout(
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
called on the sub-envs that are done. Default is True.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if auto_reset is False, an initial
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the
output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout.
environment in those dimensions (if needed).
This normally should not occur if ``tensordict`` is the output of a reset, but can occur
if ``tensordict`` is the last step of a previous rollout.
A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed
to the ``reset`` method, such as a batch-size or a device for stateless environments.
set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to
``True`` after completion of the rollout. If no ``"truncated"`` is found within the
``done_spec``, an exception is raised.
Expand Down Expand Up @@ -2565,11 +2568,7 @@ def rollout(
env_device = self.device

if auto_reset:
if tensordict is not None:
raise RuntimeError(
"tensordict cannot be provided when auto_reset is True"
)
tensordict = self.reset()
tensordict = self.reset(tensordict)
elif tensordict is None:
raise RuntimeError("tensordict must be provided when auto_reset is False")
else:
Expand Down
6 changes: 6 additions & 0 deletions torchrl/envs/custom/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .tictactoeenv import TicTacToeEnv
Loading

0 comments on commit e28440f

Please sign in to comment.