From e28440f68cf78afd93d5a7e5fd1aba7a33c18efb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Jul 2024 14:47:01 +0100 Subject: [PATCH] [Feature] TicTacToeEnv ghstack-source-id: a48a28b322af074877e9a66261310ba41a9599f0 Pull Request resolved: https://github.com/pytorch/rl/pull/2301 --- docs/source/reference/envs.rst | 12 ++ test/test_env.py | 13 ++ test/test_specs.py | 12 +- torchrl/data/tensor_specs.py | 2 +- torchrl/envs/__init__.py | 1 + torchrl/envs/common.py | 15 +- torchrl/envs/custom/__init__.py | 6 + torchrl/envs/custom/tictactoeenv.py | 281 ++++++++++++++++++++++++++++ 8 files changed, 330 insertions(+), 12 deletions(-) create mode 100644 torchrl/envs/custom/__init__.py create mode 100644 torchrl/envs/custom/tictactoeenv.py diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index c7b0eba35c0..023d93738cd 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -335,6 +335,18 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w ParallelEnv EnvCreator + +Custom native TorchRL environments +---------------------------------- + +TorchRL offers a series of custom built-in environments. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + TicTacToeEnv + Multi-agent environments ------------------------ diff --git a/test/test_env.py b/test/test_env.py index 32e9ffccb55..e151ddaae0c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -80,6 +80,7 @@ EnvCreator, ParallelEnv, SerialEnv, + TicTacToeEnv, ) from torchrl.envs.batched_envs import _stackable from torchrl.envs.gym_like import default_info_dict_reader @@ -3307,6 +3308,18 @@ def test_partial_rest(self, batched): assert s["next", "string"] == ["6", "6"] +class TestCustomEnvs: + def test_tictactoe(self): + torch.manual_seed(0) + env = TicTacToeEnv() + check_env_specs(env) + for _ in range(10): + r = env.rollout(10) + assert r.shape[-1] < 10 + r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + assert r.shape[-1] < 10 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_specs.py b/test/test_specs.py index 6b779811f1d..2d597d770f0 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3013,7 +3013,9 @@ def test_repr(self): space=None, device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3])), + domain=continuous), + device=cpu, + shape=torch.Size([3])), 1 -> lidar: BoundedTensorSpec( shape=torch.Size([20]), @@ -3031,7 +3033,9 @@ def test_repr(self): high=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3])), + domain=continuous), + device=cpu, + shape=torch.Size([3])), 2 -> individual_2_obs: CompositeSpec( individual_1_obs_0: UnboundedContinuousTensorSpec( @@ -3039,7 +3043,9 @@ def test_repr(self): space=None, device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3]))}}, + domain=continuous), + device=cpu, + shape=torch.Size([3]))}}, device=cpu, shape={torch.Size((3,))}, stack_dim={c.stack_dim})""" diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ae5b58a06a0..7c787b3ccfc 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4100,7 +4100,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str}, device={self._device}, shape={self.shape})" + return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})" def type_check( self, diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 8475979a3ba..748bef78d0b 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,6 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict +from .custom import TicTacToeEnv from .env_creator import EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index eaf701fde34..b9216b58e86 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2355,10 +2355,13 @@ def rollout( break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is called on the sub-envs that are done. Default is True. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. - tensordict (TensorDict, optional): if auto_reset is False, an initial + tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the - environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the - output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout. + environment in those dimensions (if needed). + This normally should not occur if ``tensordict`` is the output of a reset, but can occur + if ``tensordict`` is the last step of a previous rollout. + A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed + to the ``reset`` method, such as a batch-size or a device for stateless environments. set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to ``True`` after completion of the rollout. If no ``"truncated"`` is found within the ``done_spec``, an exception is raised. @@ -2565,11 +2568,7 @@ def rollout( env_device = self.device if auto_reset: - if tensordict is not None: - raise RuntimeError( - "tensordict cannot be provided when auto_reset is True" - ) - tensordict = self.reset() + tensordict = self.reset(tensordict) elif tensordict is None: raise RuntimeError("tensordict must be provided when auto_reset is False") else: diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py new file mode 100644 index 00000000000..c56a5ee5128 --- /dev/null +++ b/torchrl/envs/custom/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py new file mode 100644 index 00000000000..a46819cab17 --- /dev/null +++ b/torchrl/envs/custom/tictactoeenv.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from typing import Optional + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data.tensor_specs import ( + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) +from torchrl.envs.common import EnvBase + + +class TicTacToeEnv(EnvBase): + """A Tic-Tac-Toe implementation. + + At each turn, one of the two players have to play. + + The environment is stateless. To run it across multiple batches, call + + >>> env.reset(TensorDict(batch_size=desired_batch_size)) + + If the ``"mask"`` entry is present, ``rand_action`` takes it into account to + generate the next action. Any policy executed on this env should take this + mask into account, as well as the turn of the player (stored in the ``"turn"`` + output entry). + + Specs: + CompositeSpec( + output_spec: CompositeSpec( + full_observation_spec: CompositeSpec( + board: DiscreteTensorSpec( + shape=torch.Size([3, 3]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + turn: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + mask: DiscreteTensorSpec( + shape=torch.Size([9]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + shape=torch.Size([])), + full_reward_spec: CompositeSpec( + player0: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + player1: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + shape=torch.Size([])), + full_done_spec: CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + terminated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + truncated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + shape=torch.Size([])), + shape=torch.Size([])), + input_spec: CompositeSpec( + full_state_spec: CompositeSpec( + board: DiscreteTensorSpec( + shape=torch.Size([3, 3]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + turn: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + mask: DiscreteTensorSpec( + shape=torch.Size([9]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), shape=torch.Size([])), + full_action_spec: CompositeSpec( + action: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=9), + dtype=torch.int64, + domain=discrete), + shape=torch.Size([])), + shape=torch.Size([])), + shape=torch.Size([])) + + To run a dummy rollout, execute the following command: + + Examples: + >>> env = TicTacToeEnv() + >>> env.rollout(10) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int64, is_shared=False), + board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False), + player0: TensorDict( + fields={ + reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False), + player1: TensorDict( + fields={ + reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False) + + """ + + # batch_locked is set to False since various batch sizes can be provided to the env + batch_locked: bool = False + + def __init__(self, device=None): + super().__init__() + self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec( + n=9, + shape=(), + device=device, + ) + + self.full_observation_spec: CompositeSpec = CompositeSpec( + board=UnboundedContinuousTensorSpec( + shape=(3, 3), dtype=torch.int, device=device + ), + turn=DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.int, + device=device, + ), + mask=DiscreteTensorSpec( + 2, + shape=(9,), + dtype=torch.bool, + device=device, + ), + device=device, + ) + self.state_spec: CompositeSpec = self.observation_spec.clone() + + self.reward_spec: UnboundedContinuousTensorSpec = CompositeSpec( + { + ("player0", "reward"): UnboundedContinuousTensorSpec( + shape=(1,), device=device + ), + ("player1", "reward"): UnboundedContinuousTensorSpec( + shape=(1,), device=device + ), + }, + device=device, + ) + + self.full_done_spec: DiscreteTensorSpec = CompositeSpec( + done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device), + device=device, + ) + self.full_done_spec["terminated"] = self.full_done_spec["done"].clone() + self.full_done_spec["truncated"] = self.full_done_spec["done"].clone() + + def _reset(self, reset_td: TensorDict) -> TensorDict: + shape = reset_td.shape if reset_td is not None else () + state = self.state_spec.zero(shape) + state["board"] -= 1 + state["mask"].fill_(True) + return state.update(self.full_done_spec.zero(shape)) + + def _step(self, state: TensorDict) -> TensorDict: + + board = state["board"].clone() + turn = state["turn"].clone() + action = state["action"] + board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1) + wins = self.win(state["board"], action) + + mask = board.flatten(-2, -1) == -1 + done = wins | ~mask.any(-1, keepdim=True) + terminated = done.clone() + + reward_0 = wins & (turn == 0) + reward_1 = wins & (turn == 1) + + state = TensorDict( + { + "done": done, + "terminated": terminated, + ("player0", "reward"): reward_0.float(), + ("player1", "reward"): reward_1.float(), + "board": torch.where(board == -1, board, 1 - board), + "turn": 1 - state["turn"], + "mask": mask, + }, + batch_size=state.batch_size, + ) + return state + + def _set_seed(self, seed: int | None): + ... + + @staticmethod + def win(board: torch.Tensor, action: torch.Tensor): + row = action // 3 # type: ignore + col = action % 3 # type: ignore + return ( + board[..., row, :].sum() + == 3 | board[..., col].sum() + == 3 | board.diagonal(0, -2, -1).sum() + == 3 | board.flip(-1).diagonal(0, -2, -1).sum() + == 3 + ) + + @staticmethod + def full(board: torch.Tensor) -> bool: + return torch.sym_int(board.abs().sum()) == 9 + + @staticmethod + def get_action_mask(): + pass + + def rand_action(self, tensordict: Optional[TensorDictBase] = None): + mask = tensordict.get("mask") + action_spec = self.action_spec + if tensordict.ndim: + action_spec = action_spec.expand(tensordict.shape) + else: + action_spec = action_spec.clone() + action_spec.update_mask(mask) + tensordict.set(self.action_key, action_spec.rand()) + return tensordict