Skip to content

Commit

Permalink
Add debug function to rebuild a state from history string.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 541670375
Change-Id: Id10a3d2e050653ffa0496bab06769b2eb17b33fc
  • Loading branch information
lanctot committed Jun 22, 2023
1 parent 760013d commit 5d41fff
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
7 changes: 6 additions & 1 deletion open_spiel/python/pybind11/pyspiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class SpielException : public std::exception {
std::string message_;
};

// Definintion of our Python module.
// Definition of our Python module.
PYBIND11_MODULE(pyspiel, m) {
m.doc() = "Open Spiel";

Expand Down Expand Up @@ -615,6 +615,11 @@ PYBIND11_MODULE(pyspiel, m) {
py::arg("mean_field_population") = -1, py::arg("observer") = nullptr,
"Run the C++ tests on a game");

m.def("build_state_from_history_string", BuildStateFromHistoryString,
"Builds a state from a game string and history string.",
py::arg("game_string"), py::arg("history_string"),
py::arg("max_steps") = -1);

// Set an error handler that will raise exceptions. These exceptions are for
// the Python interface only. When used from C++, OpenSpiel will never raise
// exceptions - the process will be terminated instead.
Expand Down
45 changes: 45 additions & 0 deletions open_spiel/spiel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#include "open_spiel/abseil-cpp/absl/algorithm/container.h"
#include "open_spiel/abseil-cpp/absl/container/btree_map.h"
#include "open_spiel/abseil-cpp/absl/random/distributions.h"
#include "open_spiel/abseil-cpp/absl/strings/ascii.h"
#include "open_spiel/abseil-cpp/absl/strings/match.h"
#include "open_spiel/abseil-cpp/absl/strings/numbers.h"
#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
#include "open_spiel/abseil-cpp/absl/strings/str_format.h"
#include "open_spiel/abseil-cpp/absl/strings/str_join.h"
Expand Down Expand Up @@ -832,4 +834,47 @@ void SpielFatalErrorWithStateInfo(const std::string& error_msg,
SpielFatalError(absl::StrCat(error_msg, "Serialized state:\n", info));
}

std::pair<std::shared_ptr<const Game>,
std::unique_ptr<State>> BuildStateFromHistoryString(
const std::string& game_string,
const std::string& history,
int max_steps) {
std::pair<std::shared_ptr<const Game>, std::unique_ptr<State>> game_and_state;
game_and_state.first = LoadGame(game_string);
game_and_state.second = game_and_state.first->NewInitialState();
std::string history_copy(absl::StripAsciiWhitespace(history));
if (history_copy[0] == '[') {
history_copy = history_copy.substr(1);
}
if (history_copy[history_copy.length() - 1] == ']') {
history_copy = history_copy.substr(0, history_copy.length() - 1);
}

std::vector<Action> legal_actions;
State* state = game_and_state.second.get();
int steps = 0;
std::vector<std::string> parts = absl::StrSplit(history_copy, ',');
for (const std::string& part : parts) {
if (max_steps > 0 && steps >= max_steps) {
break;
}
Action action;
bool atoi_ret = absl::SimpleAtoi(absl::StripAsciiWhitespace(part), &action);
if (!atoi_ret) {
SpielFatalError(absl::StrCat("Problem parsing action: ", part));
}
legal_actions = state->LegalActions();
if (absl::c_find(legal_actions, action) == legal_actions.end()) {
SpielFatalError(absl::StrCat("Illegal move detected!\nState:\n",
state->ToString(), "\nAction: ", action,
" (", state->ActionToString(action), ")\n",
"History: ", state->HistoryString()));
}
state->ApplyAction(action);
steps++;
}

return game_and_state;
}

} // namespace open_spiel
12 changes: 12 additions & 0 deletions open_spiel/spiel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,18 @@ void SpielFatalErrorWithStateInfo(const std::string& error_msg,
const Game& game,
const State& state);


// Builds the state from a history string. Checks legalities of every action
// on the way. The history string is a comma-separated actions with whitespace
// allowed, and can include square brackets on either side:
// E.g. "[1, 3, 4, 5, 6]" and "57,12,72,85" are both valid.
// Proceeds up to a maximum of max_steps, unless max_steps is negative, in
// which case it proceeds until the end of the sequence.
std::pair<std::shared_ptr<const Game>,
std::unique_ptr<State>> BuildStateFromHistoryString(
const std::string& game_string, const std::string& history,
int max_steps = -1);

} // namespace open_spiel

#endif // OPEN_SPIEL_SPIEL_H_
12 changes: 11 additions & 1 deletion open_spiel/tests/console_play_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,22 @@ void ConsolePlayTest(
bool applied_action = true;
std::unique_ptr<State> new_state;

while (!state->IsTerminal()) {
while (true) {
if (applied_action) {
std::cout << state->ToString() << std::endl << std::endl;
}
applied_action = false;
Player player = state->CurrentPlayer();
std::vector<Action> legal_actions = state->LegalActions();

if (state->IsTerminal()) {
std::cout << "Warning! State is terminal. Returns: ";
for (Player p = 0; p < game.NumPlayers(); ++p) {
std::cout << state->PlayerReturn(p) << " ";
}
std::cout << std::endl;
}

if (bots != nullptr && bots->at(player) != nullptr) {
Action action = bots->at(player)->Step(*state);
std::cout << "Bot chose action: " << state->ActionToString(player, action)
Expand All @@ -109,12 +117,14 @@ void ConsolePlayTest(
if (line.empty()) {
PrintHelpMenu();
} else if (line == "#b") {
Action last_action = state->History().back();
new_state = game.NewInitialState();
std::vector<Action> history = state->History();
for (int i = 0; i < history.size() - 1; ++i) {
new_state->ApplyAction(history[i]);
}
state = std::move(new_state);
std::cout << "Popped action: " << last_action << std::endl;
applied_action = true;
} else if (line == "#q") {
return;
Expand Down

0 comments on commit 5d41fff

Please sign in to comment.