Skip to content

Commit

Permalink
Some modifications to mfg_crowd_modelling and mfg_crowd_modelling_2d:
Browse files Browse the repository at this point in the history
- Remove many uses of ParameterValue by storing values in the game object instead
- Remove some strings from the State object since they were never used except the constructor
- Re-register mfg_crowd_modelling_2d game
- Implement serialization for mfg_crowd_modelling_2d
- Add mfg_crowd_modelling_2d to the games lists for tests
- Add a playthrough for mfg_crowd_modelling_2d

PiperOrigin-RevId: 379188040
Change-Id: I104c2795b3b7082fae7a0990e9753c2510a2104e
  • Loading branch information
lanctot authored and open_spiel@google.com committed Jun 13, 2021
1 parent 4b56534 commit bc55df5
Show file tree
Hide file tree
Showing 9 changed files with 462 additions and 91 deletions.
7 changes: 3 additions & 4 deletions open_spiel/games/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,9 @@ add_executable(crowd_modelling_test mfg/crowd_modelling_test.cc ${OPEN_SPIEL_OBJ
$<TARGET_OBJECTS:tests>)
add_test(crowd_modelling_test crowd_modelling_test)

#TODO(perolat): re-enable this when serialization working.
#add_executable(crowd_modelling_2d_test mfg/crowd_modelling_2d_test.cc ${OPEN_SPIEL_OBJECTS}
# $<TARGET_OBJECTS:tests>)
#add_test(crowd_modelling_2d_test crowd_modelling_2d_test)
add_executable(crowd_modelling_2d_test mfg/crowd_modelling_2d_test.cc ${OPEN_SPIEL_OBJECTS}
$<TARGET_OBJECTS:tests>)
add_test(crowd_modelling_2d_test crowd_modelling_2d_test)

add_executable(cursor_go_test cursor_go_test.cc
${OPEN_SPIEL_OBJECTS}
Expand Down
23 changes: 11 additions & 12 deletions open_spiel/games/mfg/crowd_modelling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,10 @@ CrowdModellingState::CrowdModellingState(std::shared_ptr<const Game> game,
horizon_(horizon),
distribution_(size_, 1. / size_) {}

CrowdModellingState::CrowdModellingState(std::shared_ptr<const Game> game,
int size, int horizon,
Player current_player,
bool is_chance_init, int x, int t,
int last_action, double return_value,
std::vector<double> distribution)
CrowdModellingState::CrowdModellingState(
std::shared_ptr<const Game> game, int size, int horizon,
Player current_player, bool is_chance_init, int x, int t, int last_action,
double return_value, const std::vector<double>& distribution)
: State(game),
size_(size),
horizon_(horizon),
Expand All @@ -99,7 +97,7 @@ CrowdModellingState::CrowdModellingState(std::shared_ptr<const Game> game,
t_(t),
last_action_(last_action),
return_value_(return_value),
distribution_(std::move(distribution)) {}
distribution_(distribution) {}

std::vector<Action> CrowdModellingState::LegalActions() const {
if (IsTerminal()) return {};
Expand Down Expand Up @@ -230,11 +228,13 @@ std::string CrowdModellingState::Serialize() const {
}

CrowdModellingGame::CrowdModellingGame(const GameParameters& params)
: Game(kGameType, params) {}
: Game(kGameType, params),
size_(ParameterValue<int>("size", kDefaultSize)),
horizon_(ParameterValue<int>("horizon", kDefaultHorizon)) {}

std::vector<int> CrowdModellingGame::ObservationTensorShape() const {
// +1 to allow for t_ == horizon.
return {ParameterValue<int>("size") + ParameterValue<int>("horizon") + 1};
return {size_ + horizon_ + 1};
}

std::unique_ptr<State> CrowdModellingGame::DeserializeState(
Expand Down Expand Up @@ -271,9 +271,8 @@ std::unique_ptr<State> CrowdModellingGame::DeserializeState(
distribution.push_back(parsed_weight);
}
return absl::make_unique<CrowdModellingState>(
shared_from_this(), ParameterValue<int>("size"),
ParameterValue<int>("horizon"), current_player, is_chance_init, x, t,
last_action, return_value, std::move(distribution));
shared_from_this(), size_, horizon_, current_player, is_chance_init, x, t,
last_action, return_value, distribution);
}

} // namespace crowd_modelling
Expand Down
21 changes: 11 additions & 10 deletions open_spiel/games/mfg/crowd_modelling.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class CrowdModellingState : public State {
public:
CrowdModellingState(std::shared_ptr<const Game> game, int size, int horizon);
CrowdModellingState(std::shared_ptr<const Game> game, int size, int horizon,
Player current_player, bool is_chance_init_, int x, int t,
Player current_player, bool is_chance_init, int x, int t,
int last_action, double return_value,
std::vector<double> distribution);
const std::vector<double>& distribution);

CrowdModellingState(const CrowdModellingState&) = default;
CrowdModellingState& operator=(const CrowdModellingState&) = default;
Expand Down Expand Up @@ -119,9 +119,8 @@ class CrowdModellingGame : public Game {
explicit CrowdModellingGame(const GameParameters& params);
int NumDistinctActions() const override { return kNumActions; }
std::unique_ptr<State> NewInitialState() const override {
return absl::make_unique<CrowdModellingState>(
shared_from_this(), ParameterValue<int>("size"),
ParameterValue<int>("horizon"));
return absl::make_unique<CrowdModellingState>(shared_from_this(), size_,
horizon_);
}
int NumPlayers() const override { return kNumPlayers; }
double MinUtility() const override {
Expand All @@ -131,19 +130,21 @@ class CrowdModellingGame : public Game {
double MaxUtility() const override {
return std::numeric_limits<double>::infinity();
}
int MaxGameLength() const override {
return ParameterValue<int>("horizon", kDefaultHorizon);
}
int MaxGameLength() const override { return horizon_; }
int MaxChanceNodesInHistory() const override {
// + 1 to account for the initial extra chance node.
return ParameterValue<int>("horizon", kDefaultHorizon) + 1;
return horizon_ + 1;
}
std::vector<int> ObservationTensorShape() const override;
int MaxChanceOutcomes() const override {
return std::max(ParameterValue<int>("size"), kNumChanceActions);
return std::max(size_, kNumChanceActions);
}
std::unique_ptr<State> DeserializeState(
const std::string& str) const override;

private:
const int size_;
const int horizon_;
};

} // namespace crowd_modelling
Expand Down
141 changes: 107 additions & 34 deletions open_spiel/games/mfg/crowd_modelling_2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ const GameType kGameType{
{
{"size", GameParameter(kDefaultSize)},
{"horizon", GameParameter(kDefaultHorizon)},
{"only_distribution_reward", GameParameter(kOnlyDistributionReward)},
{"forbidden_states", GameParameter(kForbiddenStates)},
{"initial_distribution", GameParameter(kInitialDistribution)},
{"only_distribution_reward",
GameParameter(kDefaultOnlyDistributionReward)},
{"forbidden_states", GameParameter(kDefaultForbiddenStates)},
{"initial_distribution", GameParameter(kDefaultInitialDistribution)},
{"initial_distribution_value",
GameParameter(kInitialDistributionValue)},
GameParameter(kDefaultInitialDistributionValue)},
},
/*default_loadable*/ true,
/*provides_factored_observation_string*/ false};

std::shared_ptr<const Game> Factory(const GameParameters& params) {
return std::shared_ptr<const Game>(new CrowdModellingGame(params));
return std::shared_ptr<const Game>(new CrowdModelling2dGame(params));
}

std::string StateToString(int x, int y, int t, Player player_id,
Expand Down Expand Up @@ -171,12 +172,11 @@ std::vector<int> StringListToInts(std::vector<absl::string_view> strings,
return ints;
}

// TODO(perolat): register this game when the tests can pass.
// REGISTER_SPIEL_GAME(kGameType, Factory);
REGISTER_SPIEL_GAME(kGameType, Factory);

} // namespace

CrowdModellingState::CrowdModellingState(
CrowdModelling2dState::CrowdModelling2dState(
std::shared_ptr<const Game> game, int size, int horizon,
bool only_distribution_reward, const std::string& forbidden_states,
const std::string& initial_distribution,
Expand All @@ -185,16 +185,13 @@ CrowdModellingState::CrowdModellingState(
size_(size),
horizon_(horizon),
only_distribution_reward_(only_distribution_reward),
forbidden_states_(forbidden_states),
initial_distribution_(initial_distribution),
initial_distribution_value_(initial_distribution_value),
distribution_(size_ * size_, 1. / (size_ * size_)) {
std::vector<absl::string_view> forbidden_states_list =
ProcessStringParam(forbidden_states_, size_);
ProcessStringParam(forbidden_states, size_);
std::vector<absl::string_view> initial_distribution_list =
ProcessStringParam(initial_distribution_, size_);
ProcessStringParam(initial_distribution, size_);
std::vector<absl::string_view> initial_distribution_value_list =
ProcessStringParam(initial_distribution_value_, size_);
ProcessStringParam(initial_distribution_value, size_);
SPIEL_CHECK_EQ(initial_distribution_list.size(),
initial_distribution_value_list.size());

Expand Down Expand Up @@ -247,22 +244,41 @@ CrowdModellingState::CrowdModellingState(
SPIEL_CHECK_EQ(intersection.size(), 0);
}

std::vector<Action> CrowdModellingState::LegalActions() const {
// Builds a state with every dynamic field supplied explicitly. This is the
// constructor CrowdModelling2dGame::DeserializeState uses to rebuild a state
// from its serialized form.
//
// The static configuration (size, horizon, reward mode and the three string
// parameters) is handled by delegating to the configuration-only constructor;
// the remaining arguments then overwrite the members that constructor
// initialized with defaults.
//
// `distribution` must contain size * size entries, the same invariant that
// UpdateDistribution() enforces.
CrowdModelling2dState::CrowdModelling2dState(
    std::shared_ptr<const Game> game, int size, int horizon,
    bool only_distribution_reward, const std::string& forbidden_states,
    const std::string& initial_distribution,
    const std::string& initial_distribution_value, Player current_player,
    bool is_chance_init, int x, int y, int t, int last_action,
    double return_value, const std::vector<double>& distribution)
    : CrowdModelling2dState(game, size, horizon, only_distribution_reward,
                            forbidden_states, initial_distribution,
                            initial_distribution_value) {
  SPIEL_CHECK_EQ(distribution.size(), size_ * size_);
  current_player_ = current_player;
  is_chance_init_ = is_chance_init;
  x_ = x;
  y_ = y;
  t_ = t;
  last_action_ = last_action;
  return_value_ = return_value;
  // The delegated constructor leaves distribution_ uniform; without this
  // assignment the caller-supplied distribution would be silently dropped.
  distribution_ = distribution;
}

std::vector<Action> CrowdModelling2dState::LegalActions() const {
if (IsTerminal()) return {};
if (IsChanceNode()) return LegalChanceOutcomes();
if (IsMeanFieldNode()) return {};
SPIEL_CHECK_TRUE(IsPlayerNode());
return {0, 1, 2, 3, 4};
}

ActionsAndProbs CrowdModellingState::ChanceOutcomes() const {
// Returns the chance distribution at this node: the stored initial-state
// distribution while is_chance_init_ is set, otherwise a uniform distribution
// over the five outcomes 0..4.
ActionsAndProbs CrowdModelling2dState::ChanceOutcomes() const {
  if (!is_chance_init_) {
    return {{0, 1. / 5}, {1, 1. / 5}, {2, 1. / 5}, {3, 1. / 5}, {4, 1. / 5}};
  }
  return initial_distribution_action_prob_;
}

void CrowdModellingState::DoApplyAction(Action action) {
void CrowdModelling2dState::DoApplyAction(Action action) {
SPIEL_CHECK_NE(current_player_, kMeanFieldPlayerId);
return_value_ += Rewards()[0];
int xx;
Expand Down Expand Up @@ -303,16 +319,16 @@ void CrowdModellingState::DoApplyAction(Action action) {
}
}

std::string CrowdModellingState::ActionToString(Player player,
Action action) const {
// Renders an action for display. At the initial chance node the action is
// shown as "init_state=<action>"; everywhere else it is shown as the
// "(dx,dy)" pair looked up in kActionToMoveX / kActionToMoveY.
std::string CrowdModelling2dState::ActionToString(Player player,
                                                  Action action) const {
  const bool at_initial_chance_node = IsChanceNode() && is_chance_init_;
  if (!at_initial_chance_node) {
    return absl::Substitute("($0,$1)", kActionToMoveX.at(action),
                            kActionToMoveY.at(action));
  }
  return absl::Substitute("init_state=$0", action);
}

std::vector<std::string> CrowdModellingState::DistributionSupport() {
std::vector<std::string> CrowdModelling2dState::DistributionSupport() {
std::vector<std::string> support;
support.reserve(size_ * size_);
for (int x = 0; x < size_; ++x) {
Expand All @@ -323,17 +339,17 @@ std::vector<std::string> CrowdModellingState::DistributionSupport() {
return support;
}

void CrowdModellingState::UpdateDistribution(
// Replaces the stored distribution with `distribution` and hands control to
// the chance player. May only be called while the mean-field player is the
// current player, and `distribution` must have size_ * size_ entries.
void CrowdModelling2dState::UpdateDistribution(
const std::vector<double>& distribution) {
SPIEL_CHECK_EQ(current_player_, kMeanFieldPlayerId);
SPIEL_CHECK_EQ(distribution.size(), size_ * size_);
distribution_ = distribution;
// After the mean-field update the next node is a chance node.
current_player_ = kChancePlayerId;
}

bool CrowdModellingState::IsTerminal() const { return t_ >= horizon_; }
// The state is terminal once the time index t_ has reached horizon_.
bool CrowdModelling2dState::IsTerminal() const {
  return !(t_ < horizon_);
}

std::vector<double> CrowdModellingState::Rewards() const {
std::vector<double> CrowdModelling2dState::Rewards() const {
if (current_player_ != 0) {
return {0.};
}
Expand All @@ -350,28 +366,28 @@ std::vector<double> CrowdModellingState::Rewards() const {
return {r_x + r_y + r_a + r_mu};
}

std::vector<double> CrowdModellingState::Returns() const {
std::vector<double> CrowdModelling2dState::Returns() const {
return {return_value_ + Rewards()[0]};
}

std::string CrowdModellingState::ToString() const {
// Formats the position, time, current player and initial-chance flag through
// the file-local StateToString helper.
std::string CrowdModelling2dState::ToString() const {
  std::string description =
      StateToString(x_, y_, t_, current_player_, is_chance_init_);
  return description;
}

std::string CrowdModellingState::InformationStateString(Player player) const {
// Returns the information state for `player`, which is the history string.
// `player` must be in the range [0, num_players_).
std::string CrowdModelling2dState::InformationStateString(Player player) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
return HistoryString();
}

std::string CrowdModellingState::ObservationString(Player player) const {
// Returns the observation for `player`, which is the same text as ToString().
// `player` must be in the range [0, num_players_).
std::string CrowdModelling2dState::ObservationString(Player player) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
return ToString();
}

void CrowdModellingState::ObservationTensor(Player player,
absl::Span<float> values) const {
void CrowdModelling2dState::ObservationTensor(Player player,
absl::Span<float> values) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
SPIEL_CHECK_EQ(values.size(), 2 * size_ + horizon_);
Expand All @@ -387,16 +403,73 @@ void CrowdModellingState::ObservationTensor(Player player,
values[size_ + t_] = 1.;
}

std::unique_ptr<State> CrowdModellingState::Clone() const {
return std::unique_ptr<State>(new CrowdModellingState(*this));
// Returns a copy of this state, made with the copy constructor, owned through
// a State pointer.
std::unique_ptr<State> CrowdModelling2dState::Clone() const {
  std::unique_ptr<State> copy(new CrowdModelling2dState(*this));
  return copy;
}

CrowdModellingGame::CrowdModellingGame(const GameParameters& params)
: Game(kGameType, params) {}
// Serializes the dynamic part of the state as two lines:
//   line 1: current_player,is_chance_init,x,y,t,last_action,return_value
//   line 2: the distribution weights, comma separated
// CrowdModelling2dGame::DeserializeState parses this format.
// NOTE(review): absl::StrCat / absl::StrJoin presumably format doubles with
// limited precision, so the round trip may not be bit-exact -- confirm.
std::string CrowdModelling2dState::Serialize() const {
  return absl::StrCat(current_player_, ",", is_chance_init_, ",", x_, ",", y_,
                      ",", t_, ",", last_action_, ",", return_value_, "\n",
                      absl::StrJoin(distribution_, ","));
}

std::vector<int> CrowdModellingGame::ObservationTensorShape() const {
// Reads every game parameter once and caches it in a member, so later calls
// (NewInitialState, DeserializeState, ...) do not need to go through
// ParameterValue again.
//
// "only_distribution_reward" is read as a bool: the value is forwarded to the
// `bool only_distribution_reward` parameter of the state constructors, and
// reading it as an int would request a different parameter type than the one
// the state consumes.
// NOTE(review): assumes kDefaultOnlyDistributionReward is declared as a bool
// in the header -- confirm.
CrowdModelling2dGame::CrowdModelling2dGame(const GameParameters& params)
    : Game(kGameType, params),
      size_(ParameterValue<int>("size", kDefaultSize)),
      horizon_(ParameterValue<int>("horizon", kDefaultHorizon)),
      only_distribution_reward_(ParameterValue<bool>(
          "only_distribution_reward", kDefaultOnlyDistributionReward)),
      forbidden_states_(ParameterValue<std::string>("forbidden_states",
                                                    kDefaultForbiddenStates)),
      initial_distribution_(ParameterValue<std::string>(
          "initial_distribution", kDefaultInitialDistribution)),
      initial_distribution_value_(ParameterValue<std::string>(
          "initial_distribution_value", kDefaultInitialDistributionValue)) {}

std::vector<int> CrowdModelling2dGame::ObservationTensorShape() const {
return {2 * ParameterValue<int>("size") + ParameterValue<int>("horizon")};
}

// Rebuilds a state from the text produced by CrowdModelling2dState::Serialize.
//
// Expected format (exactly two lines separated by '\n'):
//   line 1: current_player,is_chance_init,x,y,t,last_action,return_value
//   line 2: comma-separated distribution weights
// Any structural or numeric parsing failure is fatal (SpielFatalError or a
// failed SPIEL_CHECK_TRUE).
std::unique_ptr<State> CrowdModelling2dGame::DeserializeState(
const std::string& str) const {
std::vector<std::string> lines = absl::StrSplit(str, '\n');
if (lines.size() != 2) {
SpielFatalError(absl::StrCat("Expected 2 lines in serialized state, got: ",
lines.size()));
}
Player current_player;
// is_chance_init is parsed as an int (via SimpleAtoi) and converted to bool
// when passed to the state constructor.
int is_chance_init, x, y, t, last_action;
double return_value;
std::vector<double> distribution;

// First line: the seven scalar fields.
std::vector<std::string> properties = absl::StrSplit(lines[0], ',');
if (properties.size() != 7) {
SpielFatalError(
absl::StrCat("Expected 7 properties for serialized state, got: ",
properties.size()));
}
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[0], &current_player));
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[1], &is_chance_init));
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[2], &x));
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[3], &y));
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[4], &t));
SPIEL_CHECK_TRUE(absl::SimpleAtoi(properties[5], &last_action));
SPIEL_CHECK_TRUE(absl::SimpleAtod(properties[6], &return_value));
// Second line: one weight per distribution entry.
std::vector<std::string> serialized_distrib = absl::StrSplit(lines[1], ',');
distribution.reserve(serialized_distrib.size());
for (std::string& v : serialized_distrib) {
double parsed_weight;
SPIEL_CHECK_TRUE(absl::SimpleAtod(v, &parsed_weight));
distribution.push_back(parsed_weight);
}
// Static configuration comes from the values cached on the game object; the
// dynamic fields come from the parsed text.
return absl::make_unique<CrowdModelling2dState>(
shared_from_this(), size_, horizon_, only_distribution_reward_,
forbidden_states_, initial_distribution_, initial_distribution_value_,
current_player, is_chance_init, x, y, t, last_action, return_value,
distribution);
}

} // namespace crowd_modelling_2d
} // namespace open_spiel
Loading

0 comments on commit bc55df5

Please sign in to comment.