-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
1,086 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Pipfile | ||
*.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# python package requirements | ||
include requirements.txt | ||
|
||
# Sprites | ||
recursive-include pygame_spiel/images * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Pygame_spiel: Pygame-based UI to play board Reinforcement Learning games from OpenSpiel | ||
PygameSpiel is a [Pygame](https://www.pygame.org)-based library to play board games from the library [OpenSpiel](https://github.com/google-deepmind/open_spiel) against AI algorithms. | ||
|
||
## Install | ||
```bash | ||
pip install pygame_spiel | ||
``` | ||
|
||
## Version 0.1.0 | ||
Games currently available: | ||
* Breakthrough | ||
* Tic Tac Toe | ||
|
||
AI algorithms available: | ||
* mcts, DQN (currently only for breakthrough) | ||
|
||
**more to come...** | ||
|
||
## Overview | ||
Run Pygame_spiel with: | ||
|
||
```bash | ||
pygame_spiel | ||
``` | ||
|
||
|
||
|
||
Use your mouse to select the cell (tic tac toe) or select pawn and destination cell (breakthrough). | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pygame | ||
import pygame_menu | ||
from pygame_menu import themes | ||
|
||
from pygame_spiel.games.settings import GAMES_BOTS | ||
from pygame_spiel.games.factory import GameFactory | ||
from pygame_spiel.menu import Menu | ||
from pygame_spiel.main import pygame_spiel | ||
|
||
def select_game(game, index): | ||
pass | ||
|
||
|
||
def main(): | ||
pygame_spiel() | ||
menu = Menu() | ||
menu.display() | ||
game_name = menu.get_selected_game() | ||
bot_type = menu.get_selected_opponent() | ||
|
||
player_id = 0 | ||
|
||
assert ( | ||
bot_type in GAMES_BOTS[game_name].keys() | ||
), f"""Bot type {bot_type} not available for game {game_name}. List of | ||
available bots: {list(GAMES_BOTS[game_name].keys())}""" | ||
|
||
game = GameFactory.get_game(game_name, current_player=player_id) | ||
game.set_bots( | ||
bot1_type='human', | ||
bot1_params=None, | ||
bot2_type=bot_type, | ||
bot2_params=None, | ||
) | ||
|
||
done = False | ||
clock = pygame.time.Clock() | ||
|
||
while not done: | ||
clock.tick(10) | ||
|
||
events = pygame.event.get() | ||
for event in events: | ||
if event.type == pygame.QUIT: | ||
done = True | ||
|
||
mouse_pos = pygame.mouse.get_pos() | ||
mouse_pressed = pygame.mouse.get_pressed() | ||
|
||
game.play(mouse_pos=mouse_pos, mouse_pressed=mouse_pressed) | ||
|
||
pygame.display.flip() | ||
|
||
|
||
if __name__ == "__main__": | ||
pygame_spiel() | ||
# main() |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import os | ||
import tensorflow.compat.v1 as tf | ||
|
||
from open_spiel.python import rl_environment | ||
from open_spiel.python.algorithms.dqn import DQN | ||
|
||
import pyspiel | ||
|
||
|
||
class DQNBot(pyspiel.Bot): | ||
"""Bot that uses DQN algorithm.""" | ||
|
||
def __init__( | ||
self, | ||
game, | ||
player_id, | ||
replay_buffer_capacity=int(1e5), | ||
batch_size=32, | ||
checkpoint_dir=None, | ||
): | ||
"""Initializes a DQN algorithm in the form of a bot. | ||
Args: | ||
game: A pyspiel.Game to play. | ||
player_id: ID associated to the player. | ||
replay_buffer_capacity: Replay buffer size | ||
batch_size: Training batch size (not used in this Bot yet) | ||
""" | ||
|
||
pyspiel.Bot.__init__(self) | ||
|
||
self._num_players = game.num_players() | ||
self._hidden_layer_sizes = [64, 64] # TODO add parameter in constructor | ||
self._env = rl_environment.Environment(game) | ||
info_state_size = self._env.observation_spec()["info_state"][0] | ||
num_actions = self._env.action_spec()["num_actions"] | ||
self._time_step = self._env.reset() | ||
|
||
self._sess = tf.Session() | ||
hidden_layers_sizes = [int(l) for l in self._hidden_layer_sizes] | ||
|
||
self._agent = DQN( | ||
session=self._sess, | ||
player_id=player_id, | ||
state_representation_size=info_state_size, | ||
num_actions=num_actions, | ||
hidden_layers_sizes=hidden_layers_sizes, | ||
replay_buffer_capacity=replay_buffer_capacity, | ||
batch_size=batch_size, | ||
) | ||
# self._agent.restore(checkpoint_dir) | ||
if checkpoint_dir is not None: | ||
if not os.path.exists(checkpoint_dir): | ||
raise FileNotFoundError("No folder exists at the location specified") | ||
self._agent.restore(checkpoint_dir) | ||
else: | ||
self._sess.run(tf.global_variables_initializer()) | ||
|
||
def restart_at(self, state): | ||
pass | ||
|
||
def step(self, state): | ||
"""Returns bot's action at given state.""" | ||
|
||
# Next lines taken from https://github.com/deepmind/open_spiel/issues/896 | ||
player_id = state.current_player() | ||
legal_actions = [ | ||
state.legal_actions(player_id) for _ in range(self._num_players) | ||
] | ||
info_state = [ | ||
state.observation_tensor(player_id) for _ in range(self._num_players) | ||
] | ||
step_type = ( | ||
rl_environment.StepType.LAST | ||
if state.is_terminal() | ||
else rl_environment.StepType.MID | ||
) | ||
time_step = rl_environment.TimeStep( | ||
observations={ | ||
"info_state": info_state, | ||
"legal_actions": legal_actions, | ||
"current_player": player_id, | ||
}, | ||
rewards=state.rewards(), | ||
discounts=[1.0, 1.0], | ||
step_type=step_type, | ||
) | ||
|
||
agent_output = self._agent.step(time_step, is_evaluation=True) | ||
action = ( | ||
agent_output.action | ||
) # TODO expand functionality to simultaneous games with apply_actions() | ||
return action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
import tensorflow.compat.v1 as tf | ||
|
||
from open_spiel.python import rl_environment | ||
from open_spiel.python.pytorch import dqn as dqn_pt | ||
|
||
import pyspiel | ||
|
||
|
||
class DQNBot(pyspiel.Bot): | ||
"""Bot that uses DQN algorithm.""" | ||
|
||
def __init__( | ||
self, | ||
game, | ||
player_id, | ||
replay_buffer_capacity=int(1e5), | ||
batch_size=32, | ||
checkpoint_dir=None, | ||
): | ||
"""Initializes a DQN algorithm in the form of a bot. | ||
Args: | ||
game: A pyspiel.Game to play. | ||
player_id: ID associated to the player. | ||
replay_buffer_capacity: Replay buffer size | ||
batch_size: Training batch size (not used in this Bot yet) | ||
""" | ||
|
||
pyspiel.Bot.__init__(self) | ||
|
||
self._num_players = game.num_players() | ||
self._hidden_layer_sizes = [64, 64] # TODO add parameter in constructor | ||
self._env = rl_environment.Environment(game) | ||
info_state_size = self._env.observation_spec()["info_state"][0] | ||
num_actions = self._env.action_spec()["num_actions"] | ||
self._time_step = self._env.reset() | ||
|
||
self._sess = tf.Session() | ||
hidden_layers_sizes = [int(l) for l in self._hidden_layer_sizes] | ||
|
||
self._agent = dqn_pt.DQN( | ||
player_id=player_id, | ||
state_representation_size=info_state_size, | ||
num_actions=num_actions, | ||
hidden_layers_sizes=hidden_layers_sizes, | ||
replay_buffer_capacity=replay_buffer_capacity, | ||
batch_size=batch_size, | ||
) | ||
# self._agent.restore(checkpoint_dir) | ||
if checkpoint_dir is not None: | ||
if not os.path.exists(checkpoint_dir): | ||
raise FileNotFoundError("No folder exists at the location specified") | ||
self._agent.restore(checkpoint_dir) | ||
else: | ||
self._sess.run(tf.global_variables_initializer()) | ||
|
||
def restart_at(self, state): | ||
pass | ||
|
||
def step(self, state): | ||
"""Returns bot's action at given state.""" | ||
|
||
# Next lines taken from https://github.com/deepmind/open_spiel/issues/896 | ||
player_id = state.current_player() | ||
legal_actions = [ | ||
state.legal_actions(player_id) for _ in range(self._num_players) | ||
] | ||
info_state = [ | ||
state.observation_tensor(player_id) for _ in range(self._num_players) | ||
] | ||
step_type = ( | ||
rl_environment.StepType.LAST | ||
if state.is_terminal() | ||
else rl_environment.StepType.MID | ||
) | ||
time_step = rl_environment.TimeStep( | ||
observations={ | ||
"info_state": info_state, | ||
"legal_actions": legal_actions, | ||
"current_player": player_id, | ||
}, | ||
rewards=state.rewards(), | ||
discounts=[1.0, 1.0], | ||
step_type=step_type, | ||
) | ||
|
||
agent_output = self._agent.step(time_step, is_evaluation=True) | ||
action = ( | ||
agent_output.action | ||
) # TODO expand functionality to simultaneous games with apply_actions() | ||
return action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import sys | ||
import importlib | ||
from pygame_spiel.games.tic_tac_toe import TicTacToe | ||
from pygame_spiel.games.breakthrough import Breakthrough |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import abc | ||
import pygame | ||
import pyspiel | ||
import typing as t | ||
import site | ||
from pathlib import Path | ||
import os | ||
|
||
from pygame_spiel.games.settings import SCREEN_SIZE, BREAKPOINTS_DRIVE_IDS | ||
from pygame_spiel.utils import init_bot, download_weights | ||
|
||
|
||
class Game(metaclass=abc.ABCMeta): | ||
def __init__(self, name, current_player): | ||
self._name = name | ||
self._current_player = current_player | ||
|
||
# Initialise game | ||
self._game = pyspiel.load_game(name) | ||
self._state = self._game.new_initial_state() | ||
self._state_string = self._state.to_string() | ||
pygame.init() | ||
|
||
self._screen = pygame.display.set_mode(SCREEN_SIZE[name]) | ||
pygame.display.set_caption(name) | ||
|
||
self._package_path = site.getsitepackages()[0] | ||
|
||
@abc.abstractmethod | ||
def play( | ||
self, mouse_pos: t.Tuple[int, int], mouse_pressed: t.Tuple[bool, bool, bool] | ||
) -> None: | ||
""" | ||
Abstact interface of the function play(). At each iteration, it requires the mouse position | ||
and state (which button was pressed, if any). | ||
Parameters: | ||
mouse_pos (tuple): Position of the mouse (X,Y coordinates) | ||
mouse_pressed (tuple): 1 if the i-th button is pressed | ||
""" | ||
|
||
def set_bots( | ||
self, bot1_type: str, bot1_params: str, bot2_type: str, bot2_params: str | ||
) -> None: | ||
""" | ||
Set a Bot for each player. Available bots are: random, human, mcts, dqn. | ||
Only 2-players game currently supported (so only two bots are set) | ||
Parameters: | ||
bot1_type (str): Bot type of player 0 | ||
bot1_params (str): Bot's parameters (e.g., neural network breakpoints) | ||
bot2_type (str): Bot type of player 1 | ||
bot2_params (str): Bot's parameters (e.g., neural network breakpoints) | ||
""" | ||
# TODO self._bot_params is not used. Remove | ||
|
||
self._bot_params = [bot1_params, bot2_params] | ||
self._bots = [] | ||
|
||
for i, bot_type in enumerate([bot1_type, bot2_type]): | ||
bot_breakpoint_dir = None | ||
if bot_type in ["dqn"]: | ||
breakpoint_dest_dir = Path( | ||
self._package_path, | ||
"pygame_spiel/data/breakpoints", | ||
bot_type, | ||
self._name, | ||
) | ||
file_id = BREAKPOINTS_DRIVE_IDS[self._name][bot_type] | ||
print(breakpoint_dest_dir) | ||
if not os.path.exists(breakpoint_dest_dir): | ||
print( | ||
f"Downloading breakpoints for bot {bot_type} and game {self._name}" | ||
) | ||
download_weights( | ||
file_id=file_id, dest_folder=str(breakpoint_dest_dir) | ||
) | ||
bot_breakpoint_dir = Path(breakpoint_dest_dir, "weights_default") | ||
bot = init_bot( | ||
bot_type, self._game, player_id=i, breakpoint_dir=bot_breakpoint_dir | ||
) | ||
self._bots.append(bot) |
Oops, something went wrong.