diff --git a/test/test_env.py b/test/test_env.py index dcf53af3765..dee03c06e7d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3319,6 +3319,24 @@ def test_tictactoe_env(self): assert r.shape[-1] < 10 r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) assert r.shape[-1] < 10 + r = env.rollout( + 100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False + ) + assert r.shape == (5, 100) + + def test_tictactoe_env_single(self): + torch.manual_seed(0) + env = TicTacToeEnv(single_player=True) + check_env_specs(env) + for _ in range(10): + r = env.rollout(10) + assert r.shape[-1] < 6 + r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + assert r.shape[-1] < 6 + r = env.rollout( + 100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False + ) + assert r.shape == (5, 100) def test_pendulum_env(self): env = PendulumEnv(device=None) diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index a46819cab17..79ea3b2dfb6 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -21,7 +21,13 @@ class TicTacToeEnv(EnvBase): """A Tic-Tac-Toe implementation. - At each turn, one of the two players have to play. + Keyword Args: + single_player (bool, optional): whether one or two players have to be + accounted for. ``single_player=True`` means that ``"player1"`` is + playing randomly. If ``False`` (default), at each turn, + one of the two players has to play. + device (torch.device, optional): the device where to put the tensors. + Defaults to ``None`` (default device). The environment is stateless. To run it across multiple batches, call @@ -163,8 +169,9 @@ class TicTacToeEnv(EnvBase): # batch_locked is set to False since various batch sizes can be provided to the env batch_locked: bool = False - def __init__(self, device=None): - super().__init__() + def __init__(self, *, single_player: bool = False, device=None): + super().__init__(device=device) + self.single_player = single_player self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec( n=9, shape=(), @@ -218,7 +225,6 @@ def _reset(self, reset_td: TensorDict) -> TensorDict: return state.update(self.full_done_spec.zero(shape)) def _step(self, state: TensorDict) -> TensorDict: - board = state["board"].clone() turn = state["turn"].clone() action = state["action"] @@ -244,6 +250,18 @@ def _step(self, state: TensorDict) -> TensorDict: }, batch_size=state.batch_size, ) + if self.single_player: + select = (~done & (turn == 0)).squeeze(-1) + if select.all(): + state_select = state + elif select.any(): + state_select = state[select] + else: + return state + state_select = self._step(self.rand_action(state_select)) + if select.all(): + return state_select + return torch.where(done, state, state_select) return state def _set_seed(self, seed: int | None):