From 56e7fe702bd8f55bed24f8d17530905bb83e18e9 Mon Sep 17 00:00:00 2001
From: Luke Marris
Date: Mon, 7 Jun 2021 07:50:11 -0600
Subject: [PATCH] Expose cut threshold when calculating CCE dist.
PiperOrigin-RevId: 377903873
Change-Id: I184688a45743081b3641874fa99ce4a218788b3e
---
open_spiel/algorithms/corr_dist.cc | 12 ++++++++----
open_spiel/algorithms/corr_dist.h | 6 ++++--
.../python/pybind11/algorithms_corr_dist.cc | 15 +++++++++++----
3 files changed, 23 insertions(+), 10 deletions(-)
diff --git a/open_spiel/algorithms/corr_dist.cc b/open_spiel/algorithms/corr_dist.cc
index 397d3edebc..8fadbc822f 100644
--- a/open_spiel/algorithms/corr_dist.cc
+++ b/open_spiel/algorithms/corr_dist.cc
@@ -255,7 +255,8 @@ double CCEDist(const Game& game, const NormalFormCorrelationDevice& mu) {
}
CorrDistInfo CCEDist(
- const Game& game, const CorrelationDevice& mu, int player) {
+ const Game& game, const CorrelationDevice& mu, int player,
+ const float prob_cut_threshold) {
// Check for proper probability distribution.
CheckCorrelationDeviceProbDist(mu);
CorrDistConfig config;
@@ -273,7 +274,7 @@ CorrDistInfo CCEDist(
CCETabularPolicy policy;
std::unique_ptr root = cce_game->NewInitialState();
TabularBestResponse best_response(
- *cce_game, player, &policy);
+ *cce_game, player, &policy, prob_cut_threshold);
// Do not populate on policy values to save unnecessary computation.
// dist_info.on_policy_values[0] = ExpectedReturns(
// *root, policy, -1, false)[player];
@@ -288,7 +289,9 @@ CorrDistInfo CCEDist(
return dist_info;
}
-CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu) {
+CorrDistInfo CCEDist(
+ const Game& game, const CorrelationDevice& mu,
+ const float prob_cut_threshold) {
// Check for proper probability distribution.
CheckCorrelationDeviceProbDist(mu);
CorrDistConfig config;
@@ -314,7 +317,8 @@ CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu) {
std::unique_ptr root = cce_game->NewInitialState();
for (auto p = Player{0}; p < cce_game->NumPlayers(); ++p) {
- TabularBestResponse best_response(*cce_game, p, &policy);
+ TabularBestResponse best_response(
+ *cce_game, p, &policy, prob_cut_threshold);
dist_info.best_response_values[p] = best_response.Value(*root);
dist_info.best_response_policies[p] = best_response.GetBestResponsePolicy();
}
diff --git a/open_spiel/algorithms/corr_dist.h b/open_spiel/algorithms/corr_dist.h
index 3b58a555d8..a9713f75c0 100644
--- a/open_spiel/algorithms/corr_dist.h
+++ b/open_spiel/algorithms/corr_dist.h
@@ -161,8 +161,10 @@ struct CorrDistInfo {
// determines which policies the opponents follow (never revealed). Note that
// the policies in this correlation device *can* be mixed. If values is
// non-null, then it is filled with the deviation incentive of each player.
-CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu);
-CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu, int player);
+CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu,
+ const float prob_cut_threshold = -1.0);
+CorrDistInfo CCEDist(const Game& game, const CorrelationDevice& mu, int player,
+ const float prob_cut_threshold = -1.0);
// Distance to a correlated equilibrium in an extensive-form game. Builds a
// simpler auxiliary game similar to the *FCE ones where there is a chance node
diff --git a/open_spiel/python/pybind11/algorithms_corr_dist.cc b/open_spiel/python/pybind11/algorithms_corr_dist.cc
index 42cc0eadbb..0a32cdf247 100644
--- a/open_spiel/python/pybind11/algorithms_corr_dist.cc
+++ b/open_spiel/python/pybind11/algorithms_corr_dist.cc
@@ -51,14 +51,21 @@ void init_pyspiel_algorithms_corr_dist(py::module& m) {
&CorrDistInfo::conditional_best_response_policies);
m.def("cce_dist",
- py::overload_cast(
+ py::overload_cast(
&open_spiel::algorithms::CCEDist),
- "Returns a player's distance to a coarse-correlated equilibrium.");
+ "Returns a player's distance to a coarse-correlated equilibrium.",
+ py::arg("game"),
+ py::arg("correlation_device"),
+ py::arg("player"),
+ py::arg("prob_cut_threshold") = -1.0);
m.def("cce_dist",
- py::overload_cast(
+ py::overload_cast(
&open_spiel::algorithms::CCEDist),
- "Returns the distance to a coarse-correlated equilibrium.");
+ "Returns the distance to a coarse-correlated equilibrium.",
+ py::arg("game"),
+ py::arg("correlation_device"),
+ py::arg("prob_cut_threshold") = -1.0);
m.def("ce_dist",
py::overload_cast(