Skip to content

Commit

Permalink
[Feature] single player for TicTacToe
Browse files Browse the repository at this point in the history
ghstack-source-id: 18fb525b49fb8b3b9aef8367a54070eb70fbb3c0
Pull Request resolved: pytorch#2303
  • Loading branch information
vmoens committed Jul 22, 2024
1 parent 04d8b52 commit 59c3374
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
18 changes: 18 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3319,6 +3319,24 @@ def test_tictactoe_env(self):
assert r.shape[-1] < 10
r = env.rollout(10, tensordict=TensorDict(batch_size=[5]))
assert r.shape[-1] < 10
r = env.rollout(
100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False
)
assert r.shape == (5, 100)

def test_tictactoe_env_single(self):
    """Check ``TicTacToeEnv`` built with ``single_player=True``.

    Validates the env specs, then checks the length of short rollouts
    (unbatched and with a batch of 5) and the shape of a long batched
    rollout that does not stop when an episode ends.

    NOTE(review): indentation was lost in the view this was reviewed from;
    the loop extent below (two short rollouts inside the loop, the long
    rollout after it) is reconstructed -- confirm against the repository.
    """
    # Seed torch's global RNG before building the env and rolling out.
    torch.manual_seed(0)
    env = TicTacToeEnv(single_player=True)
    check_env_specs(env)
    for _ in range(10):
        # Unbatched rollout capped at 10 steps: must end in fewer than 6.
        # Presumably because each step also plays the random opponent's
        # move, so the 9-cell board fills twice as fast -- TODO confirm.
        r = env.rollout(10)
        assert r.shape[-1] < 6
        # Same check when rolling out from an empty tensordict of batch size 5.
        r = env.rollout(10, tensordict=TensorDict(batch_size=[5]))
        assert r.shape[-1] < 6
    # With break_when_any_done=False the rollout is expected to run for the
    # full 100 steps for each of the 5 batch elements.
    r = env.rollout(
        100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False
    )
    assert r.shape == (5, 100)

def test_pendulum_env(self):
env = PendulumEnv(device=None)
Expand Down
26 changes: 22 additions & 4 deletions torchrl/envs/custom/tictactoeenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
class TicTacToeEnv(EnvBase):
"""A Tic-Tac-Toe implementation.
At each turn, one of the two players have to play.
Keyword Args:
single_player (bool, optional): whether one or two players have to be
accounted for. ``single_player=True`` means that ``"player1"`` is
playing randomly. If ``False`` (default), at each turn,
one of the two players has to play.
device (torch.device, optional): the device on which to place the tensors.
Defaults to ``None`` (default device).
The environment is stateless. To run it across multiple batches, call
Expand Down Expand Up @@ -163,8 +169,9 @@ class TicTacToeEnv(EnvBase):
# batch_locked is set to False since various batch sizes can be provided to the env
batch_locked: bool = False

def __init__(self, device=None):
super().__init__()
def __init__(self, *, single_player: bool = False, device=None):
super().__init__(device=device)
self.single_player = single_player
self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec(
n=9,
shape=(),
Expand Down Expand Up @@ -218,7 +225,6 @@ def _reset(self, reset_td: TensorDict) -> TensorDict:
return state.update(self.full_done_spec.zero(shape))

def _step(self, state: TensorDict) -> TensorDict:

board = state["board"].clone()
turn = state["turn"].clone()
action = state["action"]
Expand All @@ -244,6 +250,18 @@ def _step(self, state: TensorDict) -> TensorDict:
},
batch_size=state.batch_size,
)
if self.single_player:
select = (~done & (turn == 0)).squeeze(-1)
if select.all():
state_select = state
elif select.any():
state_select = state[select]
else:
return state
state_select = self._step(self.rand_action(state_select))
if select.all():
return state_select
return torch.where(done, state, state_select)
return state

def _set_seed(self, seed: int | None):
Expand Down

0 comments on commit 59c3374

Please sign in to comment.