diff --git a/open_spiel/algorithms/tabular_q_learning.cc b/open_spiel/algorithms/tabular_q_learning.cc
index e576117181..c24fc775ab 100644
--- a/open_spiel/algorithms/tabular_q_learning.cc
+++ b/open_spiel/algorithms/tabular_q_learning.cc
@@ -67,6 +67,15 @@ Action TabularQLearningSolver::SampleActionFromEpsilonGreedyPolicy(
   return GetBestAction(state, min_utility);
 }
 
+void TabularQLearningSolver::SampleUntilNextStateOrTerminal(State* state) {
+  // Repeatedly sample while chance node, so that we end up at a decision node
+  while (state->IsChanceNode() && !state->IsTerminal()) {
+    vector<Action> legal_actions = state->LegalActions();
+    state->ApplyAction(
+        legal_actions[absl::Uniform(rng_, 0, legal_actions.size())]);
+  }
+}
+
 TabularQLearningSolver::TabularQLearningSolver(std::shared_ptr<const Game> game)
     : game_(game),
       depth_limit_(kDefaultDepthLimit),
@@ -95,6 +104,7 @@ void TabularQLearningSolver::RunIteration() {
   const double min_utility = game_->MinUtility();
   // Choose start state
   std::unique_ptr<State> curr_state = game_->NewInitialState();
+  SampleUntilNextStateOrTerminal(curr_state.get());
 
   while (!curr_state->IsTerminal()) {
     const Player player = curr_state->CurrentPlayer();
@@ -104,12 +114,7 @@ void TabularQLearningSolver::RunIteration() {
         SampleActionFromEpsilonGreedyPolicy(*(curr_state.get()), min_utility);
     std::unique_ptr<State> next_state = curr_state->Child(curr_action);
 
-    // Repeatedly sample while chance node, so that we end up at a decision node
-    while (next_state->IsChanceNode() && !next_state->IsTerminal()) {
-      vector<Action> legal_actions = next_state->LegalActions();
-      next_state->ApplyAction(
-          legal_actions[absl::Uniform(rng_, 0, legal_actions.size())]);
-    }
+    SampleUntilNextStateOrTerminal(next_state.get());
     const double reward = next_state->Rewards()[player];
     // Next q-value in perspective of player to play at curr_state (important
diff --git a/open_spiel/algorithms/tabular_q_learning.h b/open_spiel/algorithms/tabular_q_learning.h
index 1c900a94af..8aa0f8cdc5 100644
--- a/open_spiel/algorithms/tabular_q_learning.h
+++ b/open_spiel/algorithms/tabular_q_learning.h
@@ -62,6 +62,10 @@ class TabularQLearningSolver {
   Action SampleActionFromEpsilonGreedyPolicy(const State& state,
                                              double min_utility);
 
+  // Moves a chance node to the next decision/terminal node by sampling from
+  // the legal actions repeatedly
+  void SampleUntilNextStateOrTerminal(State* state);
+
   std::shared_ptr<const Game> game_;
   int depth_limit_;
   double epsilon_;
diff --git a/open_spiel/algorithms/tabular_sarsa.cc b/open_spiel/algorithms/tabular_sarsa.cc
index 6582d15985..7c2a45a7b7 100644
--- a/open_spiel/algorithms/tabular_sarsa.cc
+++ b/open_spiel/algorithms/tabular_sarsa.cc
@@ -61,6 +61,15 @@ Action TabularSarsaSolver::SampleActionFromEpsilonGreedyPolicy(
   return GetBestAction(state, min_utility);
 }
 
+void TabularSarsaSolver::SampleUntilNextStateOrTerminal(State* state) {
+  // Repeatedly sample while chance node, so that we end up at a decision node
+  while (state->IsChanceNode() && !state->IsTerminal()) {
+    vector<Action> legal_actions = state->LegalActions();
+    state->ApplyAction(
+        legal_actions[absl::Uniform(rng_, 0, legal_actions.size())]);
+  }
+}
+
 TabularSarsaSolver::TabularSarsaSolver(std::shared_ptr<const Game> game)
     : game_(game),
       depth_limit_(kDefaultDepthLimit),
@@ -89,6 +98,7 @@ void TabularSarsaSolver::RunIteration() {
   double min_utility = game_->MinUtility();
   // Choose start state
   std::unique_ptr<State> curr_state = game_->NewInitialState();
+  SampleUntilNextStateOrTerminal(curr_state.get());
   Player player = curr_state->CurrentPlayer();
   // Sample action from the state using an epsilon-greedy policy
@@ -97,12 +107,7 @@
   while (!curr_state->IsTerminal()) {
     std::unique_ptr<State> next_state = curr_state->Child(curr_action);
 
-    // Repeatedly sample while chance node, so that we end up at a decision node
-    while (next_state->IsChanceNode() && !next_state->IsTerminal()) {
-      vector<Action> legal_actions = next_state->LegalActions();
-      next_state->ApplyAction(
-          legal_actions[absl::Uniform(rng_, 0, legal_actions.size())]);
-    }
+    SampleUntilNextStateOrTerminal(next_state.get());
     const double reward = next_state->Rewards()[player];
     const Action next_action =
diff --git a/open_spiel/algorithms/tabular_sarsa.h b/open_spiel/algorithms/tabular_sarsa.h
index bc8b3bb6e9..e840cead59 100644
--- a/open_spiel/algorithms/tabular_sarsa.h
+++ b/open_spiel/algorithms/tabular_sarsa.h
@@ -60,6 +60,10 @@ class TabularSarsaSolver {
   Action SampleActionFromEpsilonGreedyPolicy(const State& state,
                                              double min_utility);
 
+  // Moves a chance node to the next decision/terminal node by sampling from
+  // the legal actions repeatedly
+  void SampleUntilNextStateOrTerminal(State* state);
+
   std::shared_ptr<const Game> game_;
   int depth_limit_;
   double epsilon_;
diff --git a/open_spiel/examples/tabular_q_learning_example.cc b/open_spiel/examples/tabular_q_learning_example.cc
index ca7bb780fc..754a64537f 100644
--- a/open_spiel/examples/tabular_q_learning_example.cc
+++ b/open_spiel/examples/tabular_q_learning_example.cc
@@ -68,7 +68,34 @@ void SolveTicTacToe() {
   SPIEL_CHECK_EQ(state->Rewards()[1], 0);
 }
 
+void SolveCatch() {
+  std::shared_ptr<const Game> game = open_spiel::LoadGame("catch");
+  open_spiel::algorithms::TabularQLearningSolver tabular_q_learning_solver(
+      game);
+
+  int training_iter = 100000;
+  while (training_iter-- > 0) {
+    tabular_q_learning_solver.RunIteration();
+  }
+  const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
+      tabular_q_learning_solver.GetQValueTable();
+
+  int eval_iter = 1000;
+  int total_reward = 0;
+  while (eval_iter-- > 0) {
+    std::unique_ptr<State> state = game->NewInitialState();
+    while (!state->IsTerminal()) {
+      Action optimal_action = GetOptimalAction(q_values, state);
+      state->ApplyAction(optimal_action);
+      total_reward += state->Rewards()[0];
+    }
+  }
+
+  SPIEL_CHECK_GT(total_reward, 0);
+}
+
 int main(int argc, char** argv) {
   SolveTicTacToe();
+  SolveCatch();
   return 0;
 }
diff --git a/open_spiel/examples/tabular_sarsa_example.cc b/open_spiel/examples/tabular_sarsa_example.cc
index 294d192e50..43b6384162 100644
--- a/open_spiel/examples/tabular_sarsa_example.cc
+++ b/open_spiel/examples/tabular_sarsa_example.cc
@@ -47,15 +47,15 @@ Action GetOptimalAction(
 
 void SolveTicTacToe() {
   std::shared_ptr<const Game> game = open_spiel::LoadGame("tic_tac_toe");
-  open_spiel::algorithms::TabularSarsaSolver sarsa_solver(game);
+  open_spiel::algorithms::TabularSarsaSolver tabular_sarsa_solver(game);
 
   int iter = 100000;
   while (iter-- > 0) {
-    sarsa_solver.RunIteration();
+    tabular_sarsa_solver.RunIteration();
   }
 
   const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
-      sarsa_solver.GetQValueTable();
+      tabular_sarsa_solver.GetQValueTable();
   std::unique_ptr<State> state = game->NewInitialState();
   while (!state->IsTerminal()) {
     Action optimal_action = GetOptimalAction(q_values, state);
@@ -67,7 +67,33 @@ void SolveTicTacToe() {
   SPIEL_CHECK_EQ(state->Rewards()[1], 0);
 }
 
+void SolveCatch() {
+  std::shared_ptr<const Game> game = open_spiel::LoadGame("catch");
+  open_spiel::algorithms::TabularSarsaSolver tabular_sarsa_solver(game);
+
+  int training_iter = 100000;
+  while (training_iter-- > 0) {
+    tabular_sarsa_solver.RunIteration();
+  }
+  const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
+      tabular_sarsa_solver.GetQValueTable();
+
+  int eval_iter = 1000;
+  int total_reward = 0;
+  while (eval_iter-- > 0) {
+    std::unique_ptr<State> state = game->NewInitialState();
+    while (!state->IsTerminal()) {
+      Action optimal_action = GetOptimalAction(q_values, state);
+      state->ApplyAction(optimal_action);
+      total_reward += state->Rewards()[0];
+    }
+  }
+
+  SPIEL_CHECK_GT(total_reward, 0);
+}
+
 int main(int argc, char** argv) {
   SolveTicTacToe();
+  SolveCatch();
   return 0;
 }
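
The new SolveCatch() flow can also be exercised as a standalone program. The sketch below is a minimal driver, not part of the patch: it mirrors the added example by training TabularQLearningSolver on "catch" and then following the greedy policy. It assumes the Q-table returned by GetQValueTable() is keyed by {State::ToString(), Action}, and GreedyAction is a hypothetical stand-in for the example's GetOptimalAction helper.

// greedy_catch_demo.cc -- hedged sketch, not part of this patch.
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "open_spiel/algorithms/tabular_q_learning.h"
#include "open_spiel/spiel.h"

using open_spiel::Action;
using open_spiel::Game;
using open_spiel::State;

// Hypothetical replacement for the example's GetOptimalAction helper: pick the
// legal action with the highest learned Q-value for the current state,
// assuming the table is keyed by {State::ToString(), Action}.
Action GreedyAction(
    const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values,
    const State& state) {
  Action best = open_spiel::kInvalidAction;
  double best_value = -std::numeric_limits<double>::infinity();
  for (Action action : state.LegalActions()) {
    auto it = q_values.find({state.ToString(), action});
    const double value = (it == q_values.end()) ? 0.0 : it->second;
    if (value > best_value) {
      best_value = value;
      best = action;
    }
  }
  return best;
}

int main(int argc, char** argv) {
  std::shared_ptr<const Game> game = open_spiel::LoadGame("catch");
  open_spiel::algorithms::TabularQLearningSolver solver(game);

  // Train, as in SolveCatch() above.
  for (int i = 0; i < 100000; ++i) solver.RunIteration();

  // Evaluate the greedy policy; a trained agent should catch the ball most of
  // the time, so the accumulated reward is expected to be positive.
  const auto& q_values = solver.GetQValueTable();
  double total_reward = 0;
  for (int i = 0; i < 1000; ++i) {
    std::unique_ptr<State> state = game->NewInitialState();
    while (!state->IsTerminal()) {
      state->ApplyAction(GreedyAction(q_values, *state));
      total_reward += state->Rewards()[0];
    }
  }
  return total_reward > 0 ? 0 : 1;
}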