diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py
index 6e5dee781e8..2c93a5748ef 100644
--- a/torchrl/envs/custom/tictactoeenv.py
+++ b/torchrl/envs/custom/tictactoeenv.py
@@ -218,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict:
         turn = state["turn"].clone()
         action = state["action"]
         board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
-        wins = self.win(state["board"], action)
+        wins = self.win(board, action)
 
         mask = board.flatten(-2, -1) == -1
         done = wins | ~mask.any(-1, keepdim=True)
@@ -234,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict:
                 ("player0", "reward"): reward_0.float(),
                 ("player1", "reward"): reward_1.float(),
                 "board": torch.where(board == -1, board, 1 - board),
-                "turn": 1 - state["turn"],
+                "turn": 1 - turn,
                 "mask": mask,
             },
             batch_size=state.batch_size,
@@ -260,13 +260,15 @@ def _set_seed(self, seed: int | None):
     def win(board: torch.Tensor, action: torch.Tensor):
         row = action // 3  # type: ignore
         col = action % 3  # type: ignore
-        return (
-            board[..., row, :].sum()
-            == 3 | board[..., col].sum()
-            == 3 | board.diagonal(0, -2, -1).sum()
-            == 3 | board.flip(-1).diagonal(0, -2, -1).sum()
-            == 3
-        )
+        if board[..., row, :].sum() == 3:
+            return True
+        if board[..., col].sum() == 3:
+            return True
+        if board.diagonal(0, -2, -1).sum() == 3:
+            return True
+        if board.flip(-1).diagonal(0, -2, -1).sum() == 3:
+            return True
+        return False
 
     @staticmethod
     def full(board: torch.Tensor) -> bool: