From 56e7fe702bd8f55bed24f8d17530905bb83e18e9 Mon Sep 17 00:00:00 2001 From: Luke Marris Date: Mon, 7 Jun 2021 07:50:11 -0600 Subject: [PATCH] Expose cut threshold when calculating CCE dist. PiperOrigin-RevId: 377903873 Change-Id: I184688a45743081b3641874fa99ce4a218788b3e --- open_spiel/algorithms/corr_dist.cc | 12 ++++++++---- open_spiel/algorithms/corr_dist.h | 6 ++++-- .../python/pybind11/algorithms_corr_dist.cc | 15 +++++++++++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/open_spiel/algorithms/corr_dist.cc b/open_spiel/algorithms/corr_dist.cc index 397d3edebc..8fadbc822f 100644 --- a/open_spiel/algorithms/corr_dist.cc +++ b/open_spiel/algorithms/corr_dist.cc @@ -255,7 +255,8 @@ double CCEDist(const Game& game, const NormalFormCorrelationDevice& mu) { } CorrDistInfo CCEDist( - const Game& game, const CorrelationDevice& mu, int player) { + const Game& game, const CorrelationDevice& mu, int player, + const float prob_cut_threshold) { // Check for proper probability distribution. CheckCorrelationDeviceProbDist(mu); CorrDistConfig config; @@ -273,7 +274,7 @@ CorrDistInfo CCEDist( CCETabularPolicy policy; std::unique_ptr<State> root = cce_game->NewInitialState(); TabularBestResponse best_response( - *cce_game, player, &policy); + *cce_game, player, &policy, prob_cut_threshold); // Do not populate on policy values to save unnecessary computation. // dist_info.on_policy_values[0] = ExpectedReturns( // *root, policy, -1, false)[player]; @@ -288,7 +289,9 @@ CorrDistInfo CCEDist( return dist_info; } -CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu) { +CorrDistInfo CCEDist( + const Game& game, const CorrelationDevice& mu, + const float prob_cut_threshold) { // Check for proper probability distribution. 
CheckCorrelationDeviceProbDist(mu); CorrDistConfig config; @@ -314,7 +317,8 @@ CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu) { std::unique_ptr<State> root = cce_game->NewInitialState(); for (auto p = Player{0}; p < cce_game->NumPlayers(); ++p) { - TabularBestResponse best_response(*cce_game, p, &policy); + TabularBestResponse best_response( + *cce_game, p, &policy, prob_cut_threshold); dist_info.best_response_values[p] = best_response.Value(*root); dist_info.best_response_policies[p] = best_response.GetBestResponsePolicy(); } diff --git a/open_spiel/algorithms/corr_dist.h b/open_spiel/algorithms/corr_dist.h index 3b58a555d8..a9713f75c0 100644 --- a/open_spiel/algorithms/corr_dist.h +++ b/open_spiel/algorithms/corr_dist.h @@ -161,8 +161,10 @@ struct CorrDistInfo { // determines which policies the opponents follow (never revealed). Note that // the policies in this correlation device *can* be mixed. If values is // non-null, then it is filled with the deviation incentive of each player. -CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu); -CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu, int player); +CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu, + const float prob_cut_threshold = -1.0); +CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu, int player, + const float prob_cut_threshold = -1.0); // Distance to a correlated equilibrium in an extensive-form game. 
Builds a // simpler auxiliary game similar to the *FCE ones where there is a chance node diff --git a/open_spiel/python/pybind11/algorithms_corr_dist.cc b/open_spiel/python/pybind11/algorithms_corr_dist.cc index 42cc0eadbb..0a32cdf247 100644 --- a/open_spiel/python/pybind11/algorithms_corr_dist.cc +++ b/open_spiel/python/pybind11/algorithms_corr_dist.cc @@ -51,14 +51,21 @@ void init_pyspiel_algorithms_corr_dist(py::module& m) { &CorrDistInfo::conditional_best_response_policies); m.def("cce_dist", - py::overload_cast<const Game&, const CorrelationDevice&, int>( + py::overload_cast<const Game&, const CorrelationDevice&, int, float>( &open_spiel::algorithms::CCEDist), - "Returns a player's distance to a coarse-correlated equilibrium."); + "Returns a player's distance to a coarse-correlated equilibrium.", + py::arg("game"), + py::arg("correlation_device"), + py::arg("player"), + py::arg("prob_cut_threshold") = -1.0); m.def("cce_dist", - py::overload_cast<const Game&, const CorrelationDevice&>( + py::overload_cast<const Game&, const CorrelationDevice&, float>( &open_spiel::algorithms::CCEDist), - "Returns the distance to a coarse-correlated equilibrium."); + "Returns the distance to a coarse-correlated equilibrium.", + py::arg("game"), + py::arg("correlation_device"), + py::arg("prob_cut_threshold") = -1.0); m.def("ce_dist", py::overload_cast(