diff --git a/open_spiel/game_transforms/repeated_game.cc b/open_spiel/game_transforms/repeated_game.cc index f651fa8222..2537febb45 100644 --- a/open_spiel/game_transforms/repeated_game.cc +++ b/open_spiel/game_transforms/repeated_game.cc @@ -22,6 +22,7 @@ namespace open_spiel { namespace { constexpr bool kDefaultEnableInformationState = false; +constexpr int kDefaultRecall = 1; // These parameters represent the most general case. Game specific params are // parsed once the actual stage game is supplied. @@ -43,8 +44,9 @@ const GameType kGameType{ {{"stage_game", GameParameter(GameParameter::Type::kGame, /*is_mandatory=*/true)}, {"num_repetitions", - GameParameter(GameParameter::Type::kInt, /*is_mandatory=*/true)}}, - /*default_loadable=*/false}; + GameParameter(GameParameter::Type::kInt, /*is_mandatory=*/true)}, + {"recall", GameParameter(kDefaultRecall)}}, + /*default_loadable=*/false}; std::shared_ptr Factory(const GameParameters& params) { return CreateRepeatedGame(*LoadGame(params.at("stage_game").game_value()), @@ -57,11 +59,13 @@ REGISTER_SPIEL_GAME(kGameType, Factory); RepeatedState::RepeatedState(std::shared_ptr game, std::shared_ptr stage_game, - int num_repetitions) + int num_repetitions, + int recall) : SimMoveState(game), stage_game_(stage_game), stage_game_state_(stage_game->NewInitialState()), - num_repetitions_(num_repetitions) { + num_repetitions_(num_repetitions), + recall_(recall) { actions_history_.reserve(num_repetitions_); rewards_history_.reserve(num_repetitions_); } @@ -133,11 +137,20 @@ std::string RepeatedState::InformationStateString(Player /*player*/) const { std::string RepeatedState::ObservationString(Player /*player*/) const { std::string rv; - if (actions_history_.empty()) return rv; - for (int i = 0; i < num_players_; ++i) { - absl::StrAppend( - &rv, stage_game_state_->ActionToString(i, actions_history_.back()[i]), - " "); + if (actions_history_.empty()) { return rv; } + + // Starting from the back of the history, show each player's 
moves: + for (int j = 0; + j < recall_ && static_cast(actions_history_.size()) - 1 - j >= 0; + ++j) { + int hist_idx = actions_history_.size() - 1 - j; + SPIEL_CHECK_GE(hist_idx, 0); + SPIEL_CHECK_LT(hist_idx, actions_history_.size()); + for (int i = 0; i < num_players_; ++i) { + absl::StrAppend(&rv, + stage_game_state_->ActionToString(i, actions_history_[hist_idx][i]), + " "); + } } return rv; } @@ -170,11 +183,20 @@ void RepeatedState::ObservationTensor(Player player, if (actions_history_.empty()) return; auto ptr = values.begin(); - for (int i = 0; i < num_players_; ++i) { - ptr[actions_history_.back()[i]] = 1; - ptr += stage_game_state_->LegalActions(i).size(); + // Newest step first: mark each player's action for up to recall_ steps: + for (int j = 0; + j < recall_ && static_cast(actions_history_.size()) - 1 - j >= 0; + j++) { + int hist_idx = static_cast(actions_history_.size()) - 1 - j; + SPIEL_CHECK_GE(hist_idx, 0); + SPIEL_CHECK_LT(hist_idx, actions_history_.size()); + for (int i = 0; i < num_players_; ++i) { + ptr[actions_history_[hist_idx][i]] = 1; + ptr += stage_game_state_->LegalActions(i).size(); + } } - SPIEL_CHECK_EQ(ptr, values.end()); + + SPIEL_CHECK_LE(ptr, values.end()); } void RepeatedState::ObliviousObservationTensor(Player player, @@ -227,7 +249,10 @@ RepeatedGame::RepeatedGame(std::shared_ptr stage_game, absl::optional(kDefaultEnableInformationState))), params), stage_game_(stage_game), - num_repetitions_(ParameterValue("num_repetitions")) {} + num_repetitions_(ParameterValue("num_repetitions")), + recall_(ParameterValue("recall", kDefaultRecall)) { + SPIEL_CHECK_GE(recall_, 1); +} std::shared_ptr CreateRepeatedGame(const Game& stage_game, const GameParameters& params) { @@ -254,7 +279,8 @@ std::shared_ptr CreateRepeatedGame( std::unique_ptr RepeatedGame::NewInitialState() const { return std::unique_ptr( - new RepeatedState(shared_from_this(), stage_game_, num_repetitions_)); + new RepeatedState(shared_from_this(), stage_game_, 
num_repetitions_, recall_)); } std::vector RepeatedGame::InformationStateTensorShape() const { @@ -269,7 +295,7 @@ std::vector RepeatedGame::InformationStateTensorShape() const { std::vector RepeatedGame::ObservationTensorShape() const { int size = 0; for (int i = 0; i < NumPlayers(); ++i) - size += stage_game_->NewInitialState()->LegalActions(i).size(); + size += recall_ * stage_game_->NewInitialState()->LegalActions(i).size(); return {size}; } diff --git a/open_spiel/game_transforms/repeated_game.h b/open_spiel/game_transforms/repeated_game.h index ff3e4752fe..dc3e024ede 100644 --- a/open_spiel/game_transforms/repeated_game.h +++ b/open_spiel/game_transforms/repeated_game.h @@ -31,13 +31,17 @@ // false). // "stage_game" game The game that will be repeated. // "num_repetitions" int Number of times that the game is repeated. +// "recall" int Number of most recent steps included in the +// observation when enable_infostate is false +// (default: 1). namespace open_spiel { class RepeatedState : public SimMoveState { public: RepeatedState(std::shared_ptr game, - std::shared_ptr stage_game, int num_repetitions); + std::shared_ptr stage_game, int num_repetitions, + int recall); Player CurrentPlayer() const override { return IsTerminal() ? kTerminalPlayerId : kSimultaneousPlayerId; @@ -68,6 +72,7 @@ class RepeatedState : public SimMoveState { // to state functions (e.g. LegalActions()). std::shared_ptr stage_game_state_; int num_repetitions_; + int recall_; std::vector> actions_history_{}; std::vector> rewards_history_{}; }; @@ -99,6 +104,7 @@ class RepeatedGame : public SimMoveGame { private: std::shared_ptr stage_game_; const int num_repetitions_; + const int recall_; }; // Creates a repeated game based on the stage game. 
diff --git a/open_spiel/game_transforms/repeated_game_test.cc b/open_spiel/game_transforms/repeated_game_test.cc index 3858df1c18..aeb3e05b25 100644 --- a/open_spiel/game_transforms/repeated_game_test.cc +++ b/open_spiel/game_transforms/repeated_game_test.cc @@ -91,6 +91,51 @@ void RepeatedRockPaperScissorsDefaultsTest() { RepeatedRockPaperScissorsTest(repeated_game); } +void RepeatedRockPaperScissorsRecallTwoTest() { + GameParameters params; + params["num_repetitions"] = GameParameter(1000); + params["recall"] = GameParameter(2); + std::shared_ptr repeated_game = + CreateRepeatedGame("matrix_rps", params); + SPIEL_CHECK_EQ(repeated_game->GetType().max_num_players, 2); + SPIEL_CHECK_EQ(repeated_game->GetType().min_num_players, 2); + SPIEL_CHECK_EQ(repeated_game->GetType().utility, GameType::Utility::kZeroSum); + SPIEL_CHECK_EQ(repeated_game->GetType().reward_model, + GameType::RewardModel::kRewards); + SPIEL_CHECK_TRUE(repeated_game->GetType().provides_observation_tensor); + SPIEL_CHECK_FALSE(repeated_game->GetType().provides_information_state_tensor); + + // One-hot encoding of each player's action for each of the last 2 steps. + SPIEL_CHECK_EQ(repeated_game->ObservationTensorShape()[0], 12); + + std::vector> observation_tensors = { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // first + {1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, // second + {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0} // subsequent... + }; + std::vector observation_strings = { + "", // first observation + "Rock Rock ", // second + "Rock Rock Rock Rock " // subsequent... 
+ }; + + std::unique_ptr state = repeated_game->NewInitialState(); + int step = 0; + while (!state->IsTerminal()) { + int obs_idx = std::min(step, 2); + SPIEL_CHECK_EQ(state->ObservationString(0), observation_strings[obs_idx]); + SPIEL_CHECK_EQ(state->ObservationString(1), observation_strings[obs_idx]); + SPIEL_CHECK_TRUE(absl::c_equal(state->ObservationTensor(0), + observation_tensors[obs_idx])); + SPIEL_CHECK_TRUE(absl::c_equal(state->ObservationTensor(1), + observation_tensors[obs_idx])); + state->ApplyActions({0, 0}); + step += 1; + } + + SPIEL_CHECK_EQ(step, 1000); +} + void RepeatedRockPaperScissorsInfoStateEnabledTest() { GameParameters params; params["num_repetitions"] = GameParameter(3); @@ -167,6 +212,7 @@ void RepeatedPrisonersDilemaTest() { int main(int argc, char** argv) { open_spiel::BasicRepeatedGameTest(); open_spiel::RepeatedRockPaperScissorsDefaultsTest(); + open_spiel::RepeatedRockPaperScissorsRecallTwoTest(); open_spiel::RepeatedRockPaperScissorsInfoStateEnabledTest(); open_spiel::RepeatedPrisonersDilemaTest(); } diff --git a/open_spiel/integration_tests/playthroughs/repeated_game(stage_game=matrix_rps(),num_repetitions=10).txt b/open_spiel/integration_tests/playthroughs/repeated_game(stage_game=matrix_rps(),num_repetitions=10).txt index 97916808d1..e85b79cee9 100644 --- a/open_spiel/integration_tests/playthroughs/repeated_game(stage_game=matrix_rps(),num_repetitions=10).txt +++ b/open_spiel/integration_tests/playthroughs/repeated_game(stage_game=matrix_rps(),num_repetitions=10).txt @@ -6,7 +6,7 @@ GameType.information = Information.PERFECT_INFORMATION GameType.long_name = "Repeated Rock, Paper, Scissors" GameType.max_num_players = 2 GameType.min_num_players = 2 -GameType.parameter_specification = ["num_repetitions", "stage_game"] +GameType.parameter_specification = ["num_repetitions", "recall", "stage_game"] GameType.provides_information_state_string = False GameType.provides_information_state_tensor = False GameType.provides_observation_string 
= True @@ -19,7 +19,7 @@ GameType.utility = Utility.ZERO_SUM NumDistinctActions() = 3 PolicyTensorShape() = [3] MaxChanceOutcomes() = 0 -GetParameters() = {num_repetitions=10,stage_game=matrix_rps()} +GetParameters() = {num_repetitions=10,recall=1,stage_game=matrix_rps()} NumPlayers() = 2 MinUtility() = -10.0 MaxUtility() = 10.0