Skip to content

Commit

Permalink
Bind the C++ CFRSolver into Python.
Browse files Browse the repository at this point in the history
This is both to get a fast version, and to be able to compare the C++ and Python versions together. This also adds a test checking that the C++ CFR solver and the Python one give the exact same result over a few steps.

PiperOrigin-RevId: 271108395
Change-Id: I84338294394e270b6cda6d90631b5dc0ca2e38d6
  • Loading branch information
DeepMind Technologies Ltd authored and open_spiel@google.com committed Sep 26, 2019
1 parent 86a9de1 commit 51e6529
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
27 changes: 27 additions & 0 deletions open_spiel/python/algorithms/cfr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from open_spiel.python import policy
from open_spiel.python.algorithms import cfr
from open_spiel.python.algorithms import expected_game_score
from open_spiel.python.algorithms import exploitability
import pyspiel

_KUHN_UNIFORM_POLICY = policy.TabularPolicy(pyspiel.load_game("kuhn_poker"))
Expand Down Expand Up @@ -186,6 +187,32 @@ def test_policy(self):
np.testing.assert_equal(
np.asarray([0.5, 0.5]), tabular_policy.policy_for_key(info_state_str))

@parameterized.parameters([
(pyspiel.load_game("kuhn_poker"), pyspiel.CFRSolver, cfr.CFRSolver),
(pyspiel.load_game("leduc_poker"), pyspiel.CFRSolver, cfr.CFRSolver),
(pyspiel.load_game("kuhn_poker"), pyspiel.CFRPlusSolver,
cfr.CFRPlusSolver),
(pyspiel.load_game("leduc_poker"), pyspiel.CFRPlusSolver,
cfr.CFRPlusSolver),
])
def test_cpp_algorithms_identical_to_python_algorithm(self, game, cpp_class,
python_class):
cpp_solver = cpp_class(game)
python_solver = python_class(game)

for _ in range(5):
cpp_solver.evaluate_and_update_policy()
python_solver.evaluate_and_update_policy()

cpp_avg_policy = cpp_solver.average_policy()
python_avg_policy = python_solver.average_policy()

# We do not compare the policy directly as we do not have an easy way to
# convert one to the other, so we use the exploitability as a proxy.
cpp_expl = pyspiel.nash_conv(game, cpp_avg_policy)
python_expl = exploitability.nash_conv(game, python_avg_policy)
self.assertEqual(cpp_expl, python_expl)


class CFRBRTest(parameterized.TestCase, absltest.TestCase):

Expand Down
16 changes: 16 additions & 0 deletions open_spiel/python/pybind11/pyspiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <unordered_map>

#include "open_spiel/algorithms/cfr.h"
#include "open_spiel/algorithms/evaluate_bots.h"
#include "open_spiel/algorithms/matrix_game_utils.h"
#include "open_spiel/algorithms/tabular_exploitability.h"
Expand Down Expand Up @@ -312,6 +313,21 @@ PYBIND11_MODULE(pyspiel, m) {

m.def("get_uniform_policy", &open_spiel::GetUniformPolicy);

py::class_<open_spiel::Policy> policy(m, "Policy");

py::class_<open_spiel::algorithms::CFRSolver>(m, "CFRSolver")
.def(py::init<const Game&>())
.def("evaluate_and_update_policy",
&open_spiel::algorithms::CFRSolver::EvaluateAndUpdatePolicy)
.def("average_policy",
&open_spiel::algorithms::CFRSolver::AveragePolicy);
py::class_<open_spiel::algorithms::CFRPlusSolver>(m, "CFRPlusSolver")
.def(py::init<const Game&>())
.def("evaluate_and_update_policy",
&open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
.def("average_policy",
&open_spiel::algorithms::CFRPlusSolver::AveragePolicy);

py::class_<open_spiel::algorithms::TrajectoryRecorder>(m,
"TrajectoryRecorder")
.def(py::init<const Game&, const std::unordered_map<std::string, int>&,
Expand Down

0 comments on commit 51e6529

Please sign in to comment.