Skip to content

Commit

Permalink
Add a finite-recall option for observations of repeated games.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 451514583
Change-Id: Icc5a5c6e9b63aca099a07dcd4acb4b79be7239d6
  • Loading branch information
lanctot committed May 28, 2022
1 parent ca60e95 commit 8be5c5f
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 19 deletions.
58 changes: 42 additions & 16 deletions open_spiel/game_transforms/repeated_game.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace open_spiel {
namespace {

constexpr bool kDefaultEnableInformationState = false;
constexpr int kDefaultRecall = 1;

// These parameters represent the most general case. Game specific params are
// parsed once the actual stage game is supplied.
Expand All @@ -43,8 +44,9 @@ const GameType kGameType{
{{"stage_game",
GameParameter(GameParameter::Type::kGame, /*is_mandatory=*/true)},
{"num_repetitions",
GameParameter(GameParameter::Type::kInt, /*is_mandatory=*/true)}},
/*default_loadable=*/false};
GameParameter(GameParameter::Type::kInt, /*is_mandatory=*/true)},
{"recall", GameParameter(kDefaultRecall)}},
/*default_loadable=*/false};

std::shared_ptr<const Game> Factory(const GameParameters& params) {
return CreateRepeatedGame(*LoadGame(params.at("stage_game").game_value()),
Expand All @@ -57,11 +59,13 @@ REGISTER_SPIEL_GAME(kGameType, Factory);

// Constructs the initial state of a repeated game.
//   game:            the enclosing game object, forwarded to SimMoveState.
//   stage_game:      the game played each round; a fresh initial state of it is
//                    kept in stage_game_state_.
//   num_repetitions: stored, and used to reserve capacity in the per-round
//                    action and reward histories.
//   recall:          stored in recall_; read when building observations.
RepeatedState::RepeatedState(std::shared_ptr<const Game> game,
std::shared_ptr<const Game> stage_game,
int num_repetitions)
int num_repetitions,
int recall)
: SimMoveState(game),
stage_game_(stage_game),
stage_game_state_(stage_game->NewInitialState()),
num_repetitions_(num_repetitions) {
num_repetitions_(num_repetitions),
recall_(recall) {
// Reserve up front so the histories do not need to grow their capacity while
// holding at most num_repetitions_ entries.
actions_history_.reserve(num_repetitions_);
rewards_history_.reserve(num_repetitions_);
}
Expand Down Expand Up @@ -133,11 +137,20 @@ std::string RepeatedState::InformationStateString(Player /*player*/) const {

// Returns a textual observation of the most recent rounds. The player argument
// is unused, so the same string is produced for every player. The result is
// empty while no round has been recorded in actions_history_; otherwise it
// holds, newest round first, every player's action name followed by a single
// space, for up to recall_ rounds.
std::string RepeatedState::ObservationString(Player /*player*/) const {
std::string rv;
if (actions_history_.empty()) return rv;
for (int i = 0; i < num_players_; ++i) {
absl::StrAppend(
&rv, stage_game_state_->ActionToString(i, actions_history_.back()[i]),
" ");
if (actions_history_.empty()) { return rv; }

// Starting from the back of the history, show each player's moves. The loop
// stops after recall_ rounds or when the history runs out, whichever comes
// first.
for (int j = 0;
j < recall_ && static_cast<int>(actions_history_.size()) - 1 - j >= 0;
++j) {
// NOTE(review): size() is unsigned here while the loop condition above casts
// it to int; the LT check below also compares int against size(). Consider
// casting consistently -- confirm against compiler warning settings.
int hist_idx = actions_history_.size() - 1 - j;
SPIEL_CHECK_GE(hist_idx, 0);
SPIEL_CHECK_LT(hist_idx, actions_history_.size());
for (int i = 0; i < num_players_; ++i) {
absl::StrAppend(&rv,
stage_game_state_->ActionToString(i, actions_history_[hist_idx][i]),
" ");
}
}
return rv;
}
Expand Down Expand Up @@ -170,11 +183,20 @@ void RepeatedState::ObservationTensor(Player player,
if (actions_history_.empty()) return;

auto ptr = values.begin();
for (int i = 0; i < num_players_; ++i) {
ptr[actions_history_.back()[i]] = 1;
ptr += stage_game_state_->LegalActions(i).size();
// Starting from the back of the history, show each player's moves:
for (int j = 0;
j < recall_ && static_cast<int>(actions_history_.size()) - 1 - j >= 0;
j++) {
int hist_idx = static_cast<int>(actions_history_.size()) - 1 - j;
SPIEL_CHECK_GE(hist_idx, 0);
SPIEL_CHECK_LT(hist_idx, actions_history_.size());
for (int i = 0; i < num_players_; ++i) {
ptr[actions_history_[hist_idx][i]] = 1;
ptr += stage_game_state_->LegalActions(i).size();
}
}
SPIEL_CHECK_EQ(ptr, values.end());

SPIEL_CHECK_LE(ptr, values.end());
}

void RepeatedState::ObliviousObservationTensor(Player player,
Expand Down Expand Up @@ -227,7 +249,10 @@ RepeatedGame::RepeatedGame(std::shared_ptr<const Game> stage_game,
absl::optional<bool>(kDefaultEnableInformationState))),
params),
stage_game_(stage_game),
num_repetitions_(ParameterValue<int>("num_repetitions")) {}
num_repetitions_(ParameterValue<int>("num_repetitions")),
recall_(ParameterValue<int>("recall", kDefaultRecall)) {
SPIEL_CHECK_GE(recall_, 1);
}

std::shared_ptr<const Game> CreateRepeatedGame(const Game& stage_game,
const GameParameters& params) {
Expand All @@ -254,7 +279,8 @@ std::shared_ptr<const Game> CreateRepeatedGame(

// Creates the initial RepeatedState for this game, handing it a shared pointer
// to this game, the stage game, the number of repetitions and the recall.
std::unique_ptr<State> RepeatedGame::NewInitialState() const {
return std::unique_ptr<State>(
new RepeatedState(shared_from_this(), stage_game_, num_repetitions_));
new RepeatedState(shared_from_this(), stage_game_,
num_repetitions_, recall_));
}

std::vector<int> RepeatedGame::InformationStateTensorShape() const {
Expand All @@ -269,7 +295,7 @@ std::vector<int> RepeatedGame::InformationStateTensorShape() const {
// Returns the one-dimensional shape of the observation tensor: the sum over
// players of recall_ times the number of legal actions that player has in the
// stage game's initial state.
// NOTE(review): a new stage-game initial state is created on every loop
// iteration; it looks loop-invariant and could presumably be hoisted -- confirm.
std::vector<int> RepeatedGame::ObservationTensorShape() const {
int size = 0;
for (int i = 0; i < NumPlayers(); ++i)
size += stage_game_->NewInitialState()->LegalActions(i).size();
size += recall_ * stage_game_->NewInitialState()->LegalActions(i).size();
return {size};
}

Expand Down
8 changes: 7 additions & 1 deletion open_spiel/game_transforms/repeated_game.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@
// false).
// "stage_game" game The game that will be repeated.
// "num_repetitions" int Number of times that the game is repeated.
// "recall" int Number of most recent rounds whose actions
// make up the observations when enable_infostate
// is false (default: 1).

namespace open_spiel {

class RepeatedState : public SimMoveState {
public:
RepeatedState(std::shared_ptr<const Game> game,
std::shared_ptr<const Game> stage_game, int num_repetitions);
std::shared_ptr<const Game> stage_game, int num_repetitions,
int recall);

Player CurrentPlayer() const override {
return IsTerminal() ? kTerminalPlayerId : kSimultaneousPlayerId;
Expand Down Expand Up @@ -68,6 +72,7 @@ class RepeatedState : public SimMoveState {
// to state functions (e.g. LegalActions()).
std::shared_ptr<const State> stage_game_state_;
int num_repetitions_;
int recall_;
std::vector<std::vector<Action>> actions_history_{};
std::vector<std::vector<double>> rewards_history_{};
};
Expand Down Expand Up @@ -99,6 +104,7 @@ class RepeatedGame : public SimMoveGame {
private:
std::shared_ptr<const Game> stage_game_;
const int num_repetitions_;
const int recall_;
};

// Creates a repeated game based on the stage game.
Expand Down
46 changes: 46 additions & 0 deletions open_spiel/game_transforms/repeated_game_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,51 @@ void RepeatedRockPaperScissorsDefaultsTest() {
RepeatedRockPaperScissorsTest(repeated_game);
}

// Plays repeated rock-paper-scissors with recall=2, with both players choosing
// action 0 every round, and verifies the observation string and tensor seen by
// each player before every round.
void RepeatedRockPaperScissorsRecallTwoTest() {
  constexpr int kNumRepetitions = 1000;
  constexpr int kRecall = 2;

  GameParameters params;
  params["num_repetitions"] = GameParameter(kNumRepetitions);
  params["recall"] = GameParameter(kRecall);
  std::shared_ptr<const Game> repeated_game =
      CreateRepeatedGame("matrix_rps", params);

  // Static properties of the wrapped game.
  const auto& game_type = repeated_game->GetType();
  SPIEL_CHECK_EQ(game_type.max_num_players, 2);
  SPIEL_CHECK_EQ(game_type.min_num_players, 2);
  SPIEL_CHECK_EQ(game_type.utility, GameType::Utility::kZeroSum);
  SPIEL_CHECK_EQ(game_type.reward_model, GameType::RewardModel::kRewards);
  SPIEL_CHECK_TRUE(game_type.provides_observation_tensor);
  SPIEL_CHECK_FALSE(game_type.provides_information_state_tensor);

  // One-hot encoding of each player's action in each of the two recalled
  // rounds: 2 rounds x 2 players x 3 actions.
  SPIEL_CHECK_EQ(repeated_game->ObservationTensorShape().front(), 12);

  // Expected observations, indexed by how many rounds have been played so far
  // (capped at the recall length).
  const std::vector<std::vector<int>> expected_tensors = {
      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},  // no round played yet
      {1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0},  // one round played
      {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0}   // two or more rounds played
  };
  const std::vector<std::string> expected_strings = {
      "",                      // no round played yet
      "Rock Rock ",            // one round played
      "Rock Rock Rock Rock "   // two or more rounds played
  };

  std::unique_ptr<State> state = repeated_game->NewInitialState();
  int rounds_played = 0;
  for (; !state->IsTerminal(); ++rounds_played) {
    const int expected_idx = rounds_played < kRecall ? rounds_played : kRecall;
    for (int player = 0; player < 2; ++player) {
      SPIEL_CHECK_EQ(state->ObservationString(player),
                     expected_strings[expected_idx]);
    }
    for (int player = 0; player < 2; ++player) {
      SPIEL_CHECK_TRUE(absl::c_equal(state->ObservationTensor(player),
                                     expected_tensors[expected_idx]));
    }
    state->ApplyActions({0, 0});
  }

  // The game must last exactly the configured number of rounds.
  SPIEL_CHECK_EQ(rounds_played, kNumRepetitions);
}

void RepeatedRockPaperScissorsInfoStateEnabledTest() {
GameParameters params;
params["num_repetitions"] = GameParameter(3);
Expand Down Expand Up @@ -167,6 +212,7 @@ void RepeatedPrisonersDilemaTest() {
// Test driver: runs every repeated-game test in sequence. The command-line
// arguments are not used.
int main(int argc, char** argv) {
  // Generic checks.
  open_spiel::BasicRepeatedGameTest();

  // Rock-paper-scissors variants: defaults, recall of two, info state enabled.
  open_spiel::RepeatedRockPaperScissorsDefaultsTest();
  open_spiel::RepeatedRockPaperScissorsRecallTwoTest();
  open_spiel::RepeatedRockPaperScissorsInfoStateEnabledTest();

  // Prisoner's dilemma.
  open_spiel::RepeatedPrisonersDilemaTest();

  return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ GameType.information = Information.PERFECT_INFORMATION
GameType.long_name = "Repeated Rock, Paper, Scissors"
GameType.max_num_players = 2
GameType.min_num_players = 2
GameType.parameter_specification = ["num_repetitions", "stage_game"]
GameType.parameter_specification = ["num_repetitions", "recall", "stage_game"]
GameType.provides_information_state_string = False
GameType.provides_information_state_tensor = False
GameType.provides_observation_string = True
Expand All @@ -19,7 +19,7 @@ GameType.utility = Utility.ZERO_SUM
NumDistinctActions() = 3
PolicyTensorShape() = [3]
MaxChanceOutcomes() = 0
GetParameters() = {num_repetitions=10,stage_game=matrix_rps()}
GetParameters() = {num_repetitions=10,recall=1,stage_game=matrix_rps()}
NumPlayers() = 2
MinUtility() = -10.0
MaxUtility() = 10.0
Expand Down

0 comments on commit 8be5c5f

Please sign in to comment.