Skip to content

Commit

Permalink
Adds InformationStateAsString, InformationStateTensor, and Resample i…
Browse files Browse the repository at this point in the history
…nformation state for bridge.

Fixes Observation string to be just the observation (not the information state). Does not fix the observation tensor.

PiperOrigin-RevId: 692962329
Change-Id: I7388f5d94819030154794be631df5da4c2975819
  • Loading branch information
elkhrt authored and lanctot committed Nov 22, 2024
1 parent 37a7d0e commit be54390
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 27 deletions.
98 changes: 95 additions & 3 deletions open_spiel/games/bridge/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <algorithm>
#include <array>
#include <functional>
#include <iterator>
#include <memory>
#include <set>
Expand Down Expand Up @@ -68,8 +69,8 @@ const GameType kGameType{/*short_name=*/"bridge",
GameType::RewardModel::kTerminal,
/*max_num_players=*/kNumPlayers,
/*min_num_players=*/kNumPlayers,
/*provides_information_state_string=*/false,
/*provides_information_state_tensor=*/false,
/*provides_information_state_string=*/true,
/*provides_information_state_tensor=*/true,
/*provides_observation_string=*/true,
/*provides_observation_tensor=*/true,
/*parameter_specification=*/
Expand Down Expand Up @@ -187,7 +188,35 @@ std::array<std::string, kNumSuits> FormatHand(
return cards;
}

std::string BridgeState::ObservationString(Player player) const {
std::unique_ptr<State> BridgeState::ResampleFromInfostate(
int player_id, std::function<double()> rng) const {
// Only works in the auction phase for now.
SPIEL_CHECK_TRUE(phase_ == Phase::kAuction);
std::vector<int> our_cards;
std::vector<int> other_cards;
for (int i = 0; i < kNumCards; ++i) {
if (holder_[i] == player_id) our_cards.push_back(i);
else if (holder_[i].has_value()) other_cards.push_back(i);
}
std::unique_ptr<State> new_state = GetGame()->NewInitialState();
for (int i = 0; i < kNumCards; ++i) {
if (i % kNumPlayers == player_id) {
new_state->ApplyAction(our_cards.back());
our_cards.pop_back();
} else {
const int k = static_cast<int>(rng() * other_cards.size());
new_state->ApplyAction(other_cards[k]);
other_cards[k] = other_cards.back();
other_cards.pop_back();
}
}
for (int i = kNumCards; i < history_.size(); ++i) {
new_state->ApplyAction(history_[i].action);
}
return new_state;
}

std::string BridgeState::InformationStateString(Player player) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
if (IsTerminal()) return ToString();
Expand All @@ -203,6 +232,27 @@ std::string BridgeState::ObservationString(Player player) const {
return rv;
}

std::string BridgeState::ObservationString(Player player) const {
SPIEL_CHECK_GE(player, 0);
SPIEL_CHECK_LT(player, num_players_);
if (IsTerminal()) return ToString();
std::string rv = FormatVulnerability();
auto cards = FormatHand(player, /*mark_voids=*/true, holder_);
for (int suit = kNumSuits - 1; suit >= 0; --suit)
absl::StrAppend(&rv, cards[suit], "\n");
if (phase_ == Phase::kPlay) {
absl::StrAppend(&rv, "Contract: ", contract_.ToString(), "\n");
} else if (phase_ == Phase::kAuction && history_.size() > kNumCards) {
absl::StrAppend(
&rv, FormatAuction(/*trailing_query=*/player == CurrentPlayer()));
}
if (num_cards_played_ > 0) {
absl::StrAppend(&rv, FormatPlayObservation(/*trailing_query=*/player ==
CurrentPlayer()));
}
return rv;
}

std::array<absl::optional<Player>, kNumCards> BridgeState::OriginalDeal()
const {
SPIEL_CHECK_GE(history_.size(), kNumCards);
Expand Down Expand Up @@ -286,6 +336,42 @@ std::string BridgeState::FormatPlay() const {
return rv;
}

std::string BridgeState::FormatPlayObservation(bool trailing_query) const {
SPIEL_CHECK_GT(num_cards_played_, 0);
std::string rv;
Trick trick{kInvalidPlayer, kNoTrump, 0};
Player player = (1 + contract_.declarer) % kNumPlayers;
// Previous tricks
const int completed_tricks = num_cards_played_ / kNumPlayers;
for (int i = 0; i < completed_tricks * kNumPlayers; ++i) {
if (i % kNumPlayers == 0) {
if (i > 0) player = trick.Winner();
} else {
player = (1 + player) % kNumPlayers;
}
const int card = history_[history_.size() - num_cards_played_ + i].action;
if (i % kNumPlayers == 0) {
trick = Trick(player, contract_.trumps, card);
} else {
trick.Play(player, card);
}
if (i % kNumPlayers == 0 && i > 0)
absl::StrAppend(&rv, "Trick ", (i / kNumPlayers), " won by ");
if (Partnership(trick.Winner()) == Partnership(contract_.declarer))
absl::StrAppend(&rv, "declarer\n");
else
absl::StrAppend(&rv, "defence\n");
}
// Current trick
absl::StrAppend(&rv, "Current trick: ");
for (int i = completed_tricks * kNumPlayers; i < num_cards_played_; ++i) {
const int card = history_[history_.size() - num_cards_played_ + i].action;
absl::StrAppend(&rv, CardString(card), " ");
}
if (trailing_query) absl::StrAppend(&rv, "?");
return rv;
}

std::string BridgeState::FormatResult() const {
SPIEL_CHECK_TRUE(IsTerminal());
std::string rv;
Expand All @@ -303,6 +389,12 @@ void BridgeState::ObservationTensor(Player player,
WriteObservationTensor(player, values);
}

void BridgeState::InformationStateTensor(Player player,
absl::Span<float> values) const {
SPIEL_CHECK_EQ(values.size(), game_->ObservationTensorSize());
WriteObservationTensor(player, values);
}

void BridgeState::WriteObservationTensor(Player player,
absl::Span<float> values) const {
SPIEL_CHECK_GE(player, 0);
Expand Down
11 changes: 11 additions & 0 deletions open_spiel/games/bridge/bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,22 @@ class BridgeState : public State {
std::string ToString() const override;
bool IsTerminal() const override { return phase_ == Phase::kGameOver; }
std::vector<double> Returns() const override { return returns_; }
std::string InformationStateString(Player player) const override;
std::string ObservationString(Player player) const override;
void WriteObservationTensor(Player player, absl::Span<float> values) const;
void ObservationTensor(Player player,
absl::Span<float> values) const override;
void InformationStateTensor(Player player,
absl::Span<float> values) const override;
std::unique_ptr<State> Clone() const override {
return std::unique_ptr<State>(new BridgeState(*this));
}
std::vector<Action> LegalActions() const override;
std::vector<std::pair<Action, double>> ChanceOutcomes() const override;
std::string Serialize() const override;
void SetDoubleDummyResults(ddTableResults double_dummy_results);
std::unique_ptr<State> ResampleFromInfostate(
int player_id, std::function<double()> rng) const override;

// If the state is terminal, returns the index of the final contract, into the
// arrays returned by PossibleFinalContracts and ScoreByContract.
Expand Down Expand Up @@ -176,6 +181,7 @@ class BridgeState : public State {
std::string FormatVulnerability() const;
std::string FormatAuction(bool trailing_query) const;
std::string FormatPlay() const;
std::string FormatPlayObservation(bool trailing_query) const;
std::string FormatResult() const;

const bool use_double_dummy_result_;
Expand Down Expand Up @@ -234,6 +240,11 @@ class BridgeGame : public Game {
std::max(GetPlayTensorSize(NumTricks()), kAuctionTensorSize)};
}

std::vector<int> InformationStateTensorShape() const override {
return {kNumObservationTypes +
std::max(GetPlayTensorSize(NumTricks()), kAuctionTensorSize)};
}

int MaxGameLength() const override {
return UseDoubleDummyResult() ? kMaxAuctionLength
: kMaxAuctionLength + kNumCards;
Expand Down
1 change: 1 addition & 0 deletions open_spiel/games/bridge/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void BasicGameTests() {
testing::LoadGameTest("bridge");
testing::RandomSimTest(*LoadGame("bridge"), 3);
testing::RandomSimTest(*LoadGame("bridge(use_double_dummy_result=false)"), 3);
testing::ResampleInfostateTest(*LoadGame("bridge"), 10);
}

void DeserializeStateTest() {
Expand Down
Loading

0 comments on commit be54390

Please sign in to comment.