Skip to content

Commit

Permalink
Adds visualization of game trees using graphviz.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 270871786
Change-Id: Id299c8d8ed3019eebcd6d525cf53fd7b543221b8
  • Loading branch information
DeepMind Technologies Ltd authored and open_spiel@google.com committed Sep 26, 2019
1 parent b3e58ee commit 5898968
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 0 deletions.
81 changes: 81 additions & 0 deletions open_spiel/python/examples/treeviz_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Game tree visualization example."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging

import pyspiel
from open_spiel.python.visualizations import treeviz

FLAGS = flags.FLAGS
flags.DEFINE_string("game", "kuhn_poker", "Name of the game")
flags.DEFINE_string("out", "/tmp/gametree.png", "Name of output file, e.g., "
"[*.png|*.pdf].")
flags.DEFINE_enum("prog", "dot", ["dot", "neato", "circo"], "Graphviz layout.")
flags.DEFINE_boolean("group_infosets", False, "Wether to group infosets.")
flags.DEFINE_boolean("group_terminal", False, "Wether to group terminal nodes.")
flags.DEFINE_boolean("verbose", False, "Wether to print verbose output.")


def _zero_sum_node_decorator(state):
"""Custom node decorator that only shows the return of the first player."""
attrs = treeviz.default_node_decorator(state) # get default attributes
if state.is_terminal():
attrs["label"] = str(int(state.returns()[0]))
return attrs


def main(argv):
del argv

game = pyspiel.load_game(FLAGS.game)
game_type = game.get_type()

if game_type.dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
logging.warn("%s is not turn-based. Trying to reload game as turn-based.",
FLAGS.game)
game = pyspiel.load_game_as_turn_based(FLAGS.game)
game_type = game.get_type()

if game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
raise ValueError("Game must be sequential, not {}".format(
game_type.dynamics))

if (game_type.utility == pyspiel.GameType.Utility.ZERO_SUM and
game.num_players() == 2):
logging.info("Game is zero-sum: only showing first-player's returns.")
gametree = treeviz.GameTree(
game,
node_decorator=_zero_sum_node_decorator,
group_infosets=FLAGS.group_infosets,
group_terminal=FLAGS.group_terminal)
else:
gametree = treeviz.GameTree(game) # use default decorators

if FLAGS.verbose:
logging.info("Game tree:\n%s", gametree.to_string())

gametree.draw(FLAGS.out, prog=FLAGS.prog)
logging.info("Game tree saved to file: %s", FLAGS.out)


if __name__ == "__main__":
app.run(main)
14 changes: 14 additions & 0 deletions open_spiel/python/visualizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

211 changes: 211 additions & 0 deletions open_spiel/python/visualizations/treeviz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Visualizing game trees with graphviz.
GameTree builds a `pygraphviz.AGraph` reprensentation of the game tree. The
resulting tree can be directly visualized in Jupyter notebooks or Google Colab
via SVG plotting - or written to a file by calling `draw(filename, prog="dot")`.
See `examples/treeviz_example.py` for a more detailed example.
This module relies on external dependencies, which need to be installed before
use. On a debian system follow these steps:
```
sudo apt-get install graphviz libgraphviz-dev
pip install pygraphviz
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# pylint: disable=g-import-not-at-top
try:
import pygraphviz
except (ImportError, Exception) as e:
raise ImportError(
str(e) + "\nPlease make sure to install the following dependencies:\n"
"sudo apt-get install graphviz libgraphviz-dev\n"
"pip install pygraphviz")
# pylint: enable=g-import-not-at-top

_PLAYER_SHAPES = {0: "square", 1: "ellipse"}
_PLAYER_COLORS = {-1: "black", 0: "blue", 1: "red"}
_FONTSIZE = 8
_WIDTH = _HEIGHT = 0.25
_ARROWSIZE = .5
_MARGIN = 0.01


def default_node_decorator(state):
"""Decorates a state-node of the game tree.
This method can be called by a custom decorator to prepopulate the attributes
dictionary. Then only relevant attributes need to be changed, or added.
Args:
state: The state.
Returns:
`dict` with graphviz node style attributes.
"""
player = state.current_player()
attrs = {
"label": "",
"fontsize": _FONTSIZE,
"width": _WIDTH,
"height": _HEIGHT,
"margin": _MARGIN
}
if state.is_terminal():
attrs["label"] = ", ".join(map(str, state.returns()))
attrs["shape"] = "diamond"
elif state.is_chance_node():
attrs["shape"] = "point"
attrs["width"] = _WIDTH / 2.
attrs["height"] = _HEIGHT / 2.
else:
attrs["label"] = str(state.information_state())
attrs["shape"] = _PLAYER_SHAPES.get(player, "ellipse")
attrs["color"] = _PLAYER_COLORS.get(player, "black")
return attrs


def default_edge_decorator(parent, unused_child, action):
"""Decorates a state-node of the game tree.
This method can be called by a custom decorator to prepopulate the attributes
dictionary. Then only relevant attributes need to be changed, or added.
Args:
parent: The parent state.
unused_child: The child state, not used in the default decorator.
action: `int` the selected action in the parent state.
Returns:
`dict` with graphviz node style attributes.
"""
player = parent.current_player()
attrs = {
"label": parent.action_to_string(player, action),
"fontsize": _FONTSIZE,
"arrowsize": _ARROWSIZE
}
attrs["color"] = _PLAYER_COLORS.get(player, "black")
return attrs


class GameTree(pygraphviz.AGraph):
"""Builds `pygraphviz.AGraph` of the game tree.
Attributes:
game: A `pyspiel.Game` object.
depth_limit: Maximum depth of the tree. Optional, default=-1 (no limit).
node_decorator: Decorator function for nodes (states). Optional, default=
`treeviz.default_node_decorator`.
edge_decorator: Decorator function for edges (actions). Optional, default=
`treeviz.default_edge_decorator`.
group_terminal: Whether to display all terminal states at same level,
default=False.
group_infosets: Wheter to group infosets together, default=False.
infoset_attrs: Attributes to style infoset grouping.
kwargs: Keyword arguments passed on to `pygraphviz.AGraph.__init__`.
"""

def __init__(self,
game=None,
depth_limit=-1,
node_decorator=default_node_decorator,
edge_decorator=default_edge_decorator,
group_terminal=False,
group_infosets=False,
infoset_attrs=None,
**kwargs):

kwargs["directed"] = kwargs.get("directed", True)
super(GameTree, self).__init__(**kwargs)

# We use pygraphviz.AGraph.add_subgraph to cluster nodes, and it requires a
# default constructor. Thus game needs to be optional.
if game is None:
return

self.game = game
self._node_decorator = node_decorator
self._edge_decorator = edge_decorator

self._infosets = collections.defaultdict(lambda: [])
self._terminal_nodes = []

root = game.new_initial_state()
self.add_node(self.state_to_str(root), **self._node_decorator(root))
self._build_tree(root, 0, depth_limit)

if group_infosets:
for (player, info_state), sibblings in self._infosets.items():
cluster_name = "cluster_{}_{}".format(player, info_state)
self.add_subgraph(sibblings, cluster_name,
**(infoset_attrs or {
"style": "dashed"
}))

if group_terminal:
self.add_subgraph(self._terminal_nodes, rank="same")

def state_to_str(self, state):
"""Unique string representation of a state.
Args:
state: The state.
Returns:
String representation of state.
"""
assert not state.is_simultaneous_node()
# AGraph nodes can't have empty string == None as a key, thus we prepend " "
return " " + state.history_str()

def _build_tree(self, state, depth, depth_limit):
"""Recursively builds the game tree."""
state_str = self.state_to_str(state)

if state.is_terminal():
self._terminal_nodes.append(state_str)
return
if depth > depth_limit >= 0:
return

for action in state.legal_actions():
child = state.child(action)
child_str = self.state_to_str(child)
self.add_node(child_str, **self._node_decorator(child))
self.add_edge(state_str, child_str,
**self._edge_decorator(state, child, action))

if not child.is_chance_node() and not child.is_terminal():
player = child.current_player()
info_state = child.information_state()
self._infosets[(player, info_state)].append(child_str)

self._build_tree(child, depth + 1, depth_limit)

def _repr_svg_(self):
"""Allows to render directly in Jupyter notebooks and Google Colab."""
if not self.has_layout:
self.layout(prog="dot")
return self.draw(format="svg").decode(self.encoding)

0 comments on commit 5898968

Please sign in to comment.