Merge pull request google-deepmind#1038 from giogix2:eligibility_traces_for_tabular_q_learning

PiperOrigin-RevId: 523971457
Change-Id: Ica15bfa720a64b10b4386b6aa754cf8f64d0c6b9
lanctot committed Apr 17, 2023
2 parents 33d1ef1 + 3961e99 commit ca8affc
Showing 9 changed files with 655 additions and 26 deletions.
8 changes: 8 additions & 0 deletions open_spiel/algorithms/CMakeLists.txt
@@ -172,6 +172,14 @@ add_executable(tabular_exploitability_test tabular_exploitability_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(tabular_exploitability_test tabular_exploitability_test)

add_executable(tabular_sarsa_test tabular_sarsa_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(tabular_sarsa_test tabular_sarsa_test)

add_executable(tabular_q_learning_test tabular_q_learning_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(tabular_q_learning_test tabular_q_learning_test)

add_executable(tensor_game_utils_test tensor_game_utils_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(tensor_game_utils_test tensor_game_utils_test)
50 changes: 38 additions & 12 deletions open_spiel/algorithms/tabular_q_learning.cc
@@ -32,11 +32,12 @@ Action TabularQLearningSolver::GetBestAction(const State& state,
double min_utility) {
vector<Action> legal_actions = state.LegalActions();
SPIEL_CHECK_GT(legal_actions.size(), 0);
Action best_action = legal_actions[0];
const auto state_str = state.ToString();

Action best_action = legal_actions[0];
double value = min_utility;
for (const Action& action : legal_actions) {
double q_val = values_[{state.ToString(), action}];
double q_val = values_[{state_str, action}];
if (q_val >= value) {
value = q_val;
best_action = action;
@@ -54,19 +55,21 @@ double TabularQLearningSolver::GetBestActionValue(const State& state,
return values_[{state.ToString(), GetBestAction(state, min_utility)}];
}

Action TabularQLearningSolver::SampleActionFromEpsilonGreedyPolicy(
std::pair<Action, bool>
TabularQLearningSolver::SampleActionFromEpsilonGreedyPolicy(
const State& state, double min_utility) {
vector<Action> legal_actions = state.LegalActions();
if (legal_actions.empty()) {
return kInvalidAction;
return {kInvalidAction, false};
}

if (absl::Uniform(rng_, 0.0, 1.0) < epsilon_) {
// Choose a random action
return legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())];
return {legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())],
true};
}
// Choose the best action
return GetBestAction(state, min_utility);
return {GetBestAction(state, min_utility), false};
}

void TabularQLearningSolver::SampleUntilNextStateOrTerminal(State* state) {
@@ -84,8 +87,8 @@ TabularQLearningSolver::TabularQLearningSolver(std::shared_ptr<const Game> game)
learning_rate_(kDefaultLearningRate),
discount_factor_(kDefaultDiscountFactor),
lambda_(kDefaultLambda) {
// Only support lambda=0 for now.
SPIEL_CHECK_EQ(lambda_, 0);
SPIEL_CHECK_LE(lambda_, 1);
SPIEL_CHECK_GE(lambda_, 0);

// Currently only supports 1-player or 2-player zero sum games
SPIEL_CHECK_TRUE(game_->NumPlayers() == 1 || game_->NumPlayers() == 2);
@@ -109,8 +112,8 @@ TabularQLearningSolver::TabularQLearningSolver(
learning_rate_(learning_rate),
discount_factor_(discount_factor),
lambda_(lambda) {
// Only support lambda=0 for now.
SPIEL_CHECK_EQ(lambda_, 0);
SPIEL_CHECK_LE(lambda_, 1);
SPIEL_CHECK_GE(lambda_, 0);

// Currently only supports 1-player or 2-player zero sum games
SPIEL_CHECK_TRUE(game_->NumPlayers() == 1 || game_->NumPlayers() == 2);
@@ -140,7 +143,7 @@ void TabularQLearningSolver::RunIteration() {
const Player player = curr_state->CurrentPlayer();

// Sample action from the state using an epsilon-greedy policy
Action curr_action =
auto [curr_action, chosen_uniformly] =
SampleActionFromEpsilonGreedyPolicy(*curr_state, min_utility);

std::unique_ptr<State> next_state = curr_state->Child(curr_action);
@@ -158,7 +161,30 @@ void TabularQLearningSolver::RunIteration() {
double new_q_value = reward + discount_factor_ * next_q_value;

double prev_q_val = values_[{key, curr_action}];
values_[{key, curr_action}] += learning_rate_ * (new_q_value - prev_q_val);
if (lambda_ == 0) {
// If lambda_ is equal to zero, run Q-learning as usual; there is no need
// to update eligibility traces.
values_[{key, curr_action}] +=
learning_rate_ * (new_q_value - prev_q_val);
} else {
double lambda =
player != next_state->CurrentPlayer() ? -lambda_ : lambda_;
eligibility_traces_[{key, curr_action}] += 1;

for (const auto& q_cell : values_) {
std::string state = q_cell.first.first;
Action action = q_cell.first.second;

values_[{state, action}] += learning_rate_ *
(new_q_value - prev_q_val) *
eligibility_traces_[{state, action}];
if (chosen_uniformly) {
eligibility_traces_[{state, action}] = 0;
} else {
eligibility_traces_[{state, action}] *= discount_factor_ * lambda;
}
}
}

curr_state = std::move(next_state);
}
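For orientation, here is a minimal sketch of the accumulating-trace Watkins Q(lambda) backup that the new branch above follows. The names (QLambdaBackup, td_error, alpha) are hypothetical, and the loop runs over a separate trace table rather than over the whole Q-table as the commit does; treat it as an illustration under those assumptions, not as the committed implementation.

// Illustrative sketch only (not part of this commit): one accumulating-trace
// Watkins Q(lambda) backup over a toy Q-table keyed by (state, action).
#include <map>
#include <string>
#include <utility>

using Key = std::pair<std::string, int>;  // (state string, action id)

void QLambdaBackup(std::map<Key, double>& q, std::map<Key, double>& traces,
                   const Key& current, double td_error, double alpha,
                   double gamma, double lambda, bool chosen_uniformly) {
  // "Accumulate" traces: add 1 on every visit instead of resetting to 1.
  traces[current] += 1.0;
  for (auto& [key, trace] : traces) {
    // Every previously visited pair receives a share of the TD error,
    // weighted by its trace.
    q[key] += alpha * td_error * trace;
    // Watkins's variant: cut traces after an exploratory (uniformly random)
    // action; otherwise decay them by gamma * lambda.
    trace = chosen_uniformly ? 0.0 : gamma * lambda * trace;
  }
}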
19 changes: 15 additions & 4 deletions open_spiel/algorithms/tabular_q_learning.h
@@ -34,7 +34,14 @@ namespace algorithms {
//
// Based on the implementation in Sutton and Barto, Intro to RL. Second Edition,
// 2018. Section 6.5.
// Note: current implementation only supports full bootstrapping (lambda = 0).
//
// Includes implementation of Watkins’s Q(lambda) which can be found in
// Sutton and Barto, Intro to RL. Second Edition, 2018. Section 12.10.
// (E.g. https://www.andrew.cmu.edu/course/10-703/textbook/BartoSutton.pdf)
// Eligibility traces are implemented with the "accumulate" method
// (+1 at each iteration) rather than the "replace" method (which does not
// sum trace values). The lambda_ parameter determines the level of
// bootstrapping.

class TabularQLearningSolver {
static inline constexpr double kDefaultDepthLimit = -1;
@@ -63,9 +70,11 @@ class TabularQLearningSolver {
double GetBestActionValue(const State& state, double min_utility);

// Given a player and a state, gets the action, sampled from an epsilon-greedy
// policy
Action SampleActionFromEpsilonGreedyPolicy(const State& state,
double min_utility);
// policy. Returns <action, chosen_uniformly> where the second element
// indicates whether an action was chosen uniformly (which occurs with epsilon
// chance).
std::pair<Action, bool> SampleActionFromEpsilonGreedyPolicy(
const State& state, double min_utility);

// Moves a chance node to the next decision/terminal node by sampling from
// the legal actions repeatedly
@@ -79,6 +88,8 @@ double lambda_;
double lambda_;
std::mt19937 rng_;
absl::flat_hash_map<std::pair<std::string, Action>, double> values_;
absl::flat_hash_map<std::pair<std::string, Action>, double>
eligibility_traces_;
};

} // namespace algorithms
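For the API above, a minimal usage sketch; the argument order mirrors the new test file below, and the function name TrainCatchWithTraces is illustrative only.

#include <memory>

#include "open_spiel/algorithms/tabular_q_learning.h"
#include "open_spiel/spiel.h"

// Minimal usage sketch: train with eligibility traces enabled (lambda = 0.1)
// and read back the learned Q-value table.
void TrainCatchWithTraces() {
  std::shared_ptr<const open_spiel::Game> game = open_spiel::LoadGame("catch");
  open_spiel::algorithms::TabularQLearningSolver solver(
      /*game=*/game,
      /*depth_limit=*/-1.0,
      /*epsilon=*/0.1,
      /*learning_rate=*/0.01,
      /*discount_factor=*/0.99,
      /*lambda=*/0.1);
  for (int i = 0; i < 100; ++i) solver.RunIteration();
  const auto& q_values = solver.GetQValueTable();
  (void)q_values;  // e.g. feed into a greedy policy, as the tests below do.
}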
228 changes: 228 additions & 0 deletions open_spiel/algorithms/tabular_q_learning_test.cc
@@ -0,0 +1,228 @@
// Copyright 2023 DeepMind Technologies Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/algorithms/tabular_q_learning.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "open_spiel/abseil-cpp/absl/random/distributions.h"
#include "open_spiel/abseil-cpp/absl/random/random.h"
#include "open_spiel/games/catch.h"
#include "open_spiel/spiel.h"

namespace open_spiel {
namespace {

Action GetOptimalAction(
absl::flat_hash_map<std::pair<std::string, Action>, double> q_values,
const std::unique_ptr<State> &state) {
std::vector<Action> legal_actions = state->LegalActions();
const auto state_str = state->ToString();

Action optimal_action = open_spiel::kInvalidAction;
double value = -1;
for (const Action &action : legal_actions) {
double q_val = q_values[{state_str, action}];
if (q_val >= value) {
value = q_val;
optimal_action = action;
}
}
return optimal_action;
}

Action GetRandomAction(const std::unique_ptr<State> &state, int seed) {
std::vector<Action> legal_actions = state->LegalActions();
if (legal_actions.empty()) {
return kInvalidAction;
}
std::mt19937 rng(seed);
return legal_actions[absl::Uniform<int>(rng, 0, legal_actions.size())];
}

double PlayCatch(
absl::flat_hash_map<std::pair<std::string, Action>, double> q_values,
const std::unique_ptr<State> &state, double seed) {
// The first action determines the ball's starting column. Apply it before the
// main loop, where the optimal action is chosen at each step.
// Example: Initial state with random seed 42
// ...o.
// .....
// .....
// .....
// .....
// .....
// .....
// .....
// .....
// ..x..
std::mt19937 gen(seed);
std::uniform_int_distribution<int> distribution(0,
catch_::kDefaultColumns - 1);
int ball_starting_column = distribution(gen);
state->ApplyAction(ball_starting_column);

while (!state->IsTerminal()) {
Action optimal_action = GetOptimalAction(q_values, state);
state->ApplyAction(optimal_action);
}

return state->Rewards()[0];
}

std::unique_ptr<open_spiel::algorithms::TabularQLearningSolver> QLearningSolver(
std::shared_ptr<const Game> game, double lambda) {
return std::make_unique<open_spiel::algorithms::TabularQLearningSolver>(
/*game=*/game,
/*depth_limit=*/-1.0,
/*epsilon=*/0.1,
/*learning_rate=*/0.01,
/*discount_factor=*/0.99,
/*lambda=*/lambda);
}

void TabularQLearningTest_Catch_Lambda00_Loss() {
// Classic Q-learning with no eligibility traces (lambda=0.0).
// The player loses after only 1 training iteration.
std::shared_ptr<const Game> game = LoadGame("catch");
auto tabular_q_learning_solver = QLearningSolver(game, 0);

tabular_q_learning_solver->RunIteration();
const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver->GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();

double reward = PlayCatch(q_values, state, 42);
SPIEL_CHECK_EQ(reward, -1);
}

void TabularQLearningTest_Catch_Lambda00_Win() {
// Classic Q-learning with no eligibility traces (lambda=0.0).
// The player wins after 100 training iterations.
std::shared_ptr<const Game> game = LoadGame("catch");
auto tabular_q_learning_solver = QLearningSolver(game, 0);

for (int i = 1; i < 100; i++) {
tabular_q_learning_solver->RunIteration();
}
const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver->GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();

double reward = PlayCatch(q_values, state, 42);
SPIEL_CHECK_EQ(reward, 1);
}

void TabularQLearningTest_Catch_Lambda01_Win() {
// The player wins after 100 training iterations.
std::shared_ptr<const Game> game = LoadGame("catch");
auto tabular_q_learning_solver = QLearningSolver(game, 0.1);

for (int i = 1; i < 100; i++) {
tabular_q_learning_solver->RunIteration();
}
const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver->GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();

double reward = PlayCatch(q_values, state, 42);
SPIEL_CHECK_EQ(reward, 1);
}

void TabularQLearningTest_Catch_Lambda01FasterThanLambda00() {
// With eligibility traces (lambda > 0.0), a win is always reached in no more
// training steps than with plain Q-learning (lambda=0.0).
std::shared_ptr<const Game> game = LoadGame("catch");
auto tabular_q_learning_solver_lambda00 = QLearningSolver(game, 0);
auto tabular_q_learning_solver_lambda01 = QLearningSolver(game, 0.1);

for (int seed = 0; seed < 100; seed++) {
int lambda_00_train_iter = 0;
int lambda_01_train_iter = 0;
double lambda_00_reward = -1.0;
double lambda_01_reward = -1.0;

while (lambda_00_reward == -1.0) {
tabular_q_learning_solver_lambda00->RunIteration();
std::unique_ptr<State> state = game->NewInitialState();
lambda_00_reward = PlayCatch(
tabular_q_learning_solver_lambda00->GetQValueTable(), state, seed);
lambda_00_train_iter++;
}
while (lambda_01_reward == -1.0) {
tabular_q_learning_solver_lambda01->RunIteration();
std::unique_ptr<State> state = game->NewInitialState();
lambda_01_reward = PlayCatch(
tabular_q_learning_solver_lambda01->GetQValueTable(), state, seed);
lambda_01_train_iter++;
}
SPIEL_CHECK_GE(lambda_00_train_iter, lambda_01_train_iter);
}
}

void TabularQLearningTest_TicTacToe_Lambda01_Win() {
std::shared_ptr<const Game> game = open_spiel::LoadGame("tic_tac_toe");
auto tabular_q_learning_solver = QLearningSolver(game, 0.1);

for (int i = 1; i < 100; i++) {
tabular_q_learning_solver->RunIteration();
}

const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver->GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();

while (!state->IsTerminal()) {
Action random_action = GetRandomAction(state, 42);
state->ApplyAction(random_action); // player 0
if (random_action == kInvalidAction) break;
state->ApplyAction(GetOptimalAction(q_values, state)); // player 1
}

SPIEL_CHECK_EQ(state->Rewards()[0], -1);
}

void TabularQLearningTest_TicTacToe_Lambda01_Tie() {
std::shared_ptr<const Game> game = open_spiel::LoadGame("tic_tac_toe");
auto tabular_q_learning_solver = QLearningSolver(game, 0.1);

for (int i = 1; i < 1000; i++) {
tabular_q_learning_solver->RunIteration();
}

const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver->GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();

while (!state->IsTerminal()) {
state->ApplyAction(GetOptimalAction(q_values, state));
}

SPIEL_CHECK_EQ(state->Rewards()[0], 0);
}

} // namespace
} // namespace open_spiel

int main(int argc, char **argv) {
open_spiel::TabularQLearningTest_Catch_Lambda00_Loss();
open_spiel::TabularQLearningTest_Catch_Lambda00_Win();
open_spiel::TabularQLearningTest_Catch_Lambda01_Win();
open_spiel::TabularQLearningTest_Catch_Lambda01FasterThanLambda00();
open_spiel::TabularQLearningTest_TicTacToe_Lambda01_Win();
open_spiel::TabularQLearningTest_TicTacToe_Lambda01_Tie();
}