Implement TdLambdaReturns for alpha_zero_torch #940

Open: wants to merge 2 commits into master
4 changes: 4 additions & 0 deletions open_spiel/algorithms/CMakeLists.txt
@@ -144,6 +144,10 @@ add_executable(matrix_game_utils_test matrix_game_utils_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(matrix_game_utils_test matrix_game_utils_test)

add_executable(mcts_test mcts_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(mcts_test mcts_test)

add_executable(minimax_test minimax_test.cc
$<TARGET_OBJECTS:algorithms> ${OPEN_SPIEL_OBJECTS})
add_test(minimax_test minimax_test)
2 changes: 2 additions & 0 deletions open_spiel/algorithms/alpha_zero/alpha_zero.cc
@@ -124,6 +124,7 @@ std::unique_ptr<MCTSBot> InitAZBot(
game,
std::move(evaluator),
config.uct_c,
config.min_simulations,
config.max_simulations,
/*max_memory_mb=*/ 10,
/*solve=*/ false,
@@ -231,6 +232,7 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
game,
rand_evaluator,
config.uct_c,
/*min_simulations=*/0,
rand_max_simulations,
/*max_memory_mb=*/1000,
/*solve=*/true,
2 changes: 2 additions & 0 deletions open_spiel/algorithms/alpha_zero/alpha_zero.h
@@ -41,6 +41,7 @@ struct AlphaZeroConfig {
int evaluation_window;

double uct_c;
int min_simulations;
int max_simulations;
double policy_alpha;
double policy_epsilon;
@@ -74,6 +75,7 @@ struct AlphaZeroConfig {
{"checkpoint_freq", checkpoint_freq},
{"evaluation_window", evaluation_window},
{"uct_c", uct_c},
{"min_simulations", min_simulations},
{"max_simulations", max_simulations},
{"policy_alpha", policy_alpha},
{"policy_epsilon", policy_epsilon},
5 changes: 5 additions & 0 deletions open_spiel/algorithms/alpha_zero_torch/CMakeLists.txt
@@ -18,6 +18,10 @@ if (OPEN_SPIEL_BUILD_WITH_LIBTORCH)
)
target_include_directories (alpha_zero_torch PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

add_executable(torch_alpha_zero_test alpha_zero_test.cc ${OPEN_SPIEL_OBJECTS}
$<TARGET_OBJECTS:alpha_zero_torch> $<TARGET_OBJECTS:tests>)
add_test(torch_alpha_zero_test torch_alpha_zero_test)

add_executable(torch_model_test model_test.cc ${OPEN_SPIEL_OBJECTS}
$<TARGET_OBJECTS:alpha_zero_torch> $<TARGET_OBJECTS:tests>)
add_test(torch_model_test torch_model_test)
@@ -27,6 +31,7 @@ if (OPEN_SPIEL_BUILD_WITH_LIBTORCH)
add_test(torch_vpnet_test torch_vpnet_test)

target_link_libraries (alpha_zero_torch ${TORCH_LIBRARIES})
target_link_libraries (torch_alpha_zero_test ${TORCH_LIBRARIES})
target_link_libraries (torch_model_test ${TORCH_LIBRARIES})
target_link_libraries (torch_vpnet_test ${TORCH_LIBRARIES})
endif ()
98 changes: 82 additions & 16 deletions open_spiel/algorithms/alpha_zero_torch/alpha_zero.cc
@@ -125,30 +125,43 @@ Trajectory PlayGame(Logger* logger, int game_num, const open_spiel::Game& game,
std::pow(c.explore_count, 1.0 / temperature));
}
NormalizePolicy(&policy);
open_spiel::Action action;
const SearchNode* action_node;
if (history.size() >= temperature_drop) {
action = root->BestChild().action;
action_node = &root->BestChild();
} else {
open_spiel::Action action;
action = open_spiel::SampleAction(policy, *rng).first;
for (const SearchNode& child : root->children) {
if (child.action == action) {
action_node = &child;
break;
}
}
}

double root_value = root->total_reward / root->explore_count;
double action_value =
action_node->outcome.empty()
? (action_node->total_reward / action_node->explore_count)
* (action_node->player == player ? 1 : -1)
: action_node->outcome[player];
trajectory.states.push_back(Trajectory::State{
state->ObservationTensor(), player, state->LegalActions(), action,
std::move(policy), root_value});
std::string action_str = state->ActionToString(player, action);
state->ObservationTensor(), player, state->LegalActions(),
action_node->action, std::move(policy), action_value});
std::string action_str =
state->ActionToString(player, action_node->action);
history.push_back(action_str);
state->ApplyAction(action);
state->ApplyAction(action_node->action);
if (verbose) {
logger->Print("Player: %d, action: %s", player, action_str);
logger->Print("Player: %d, action: %s, value: %6.3f",
player, action_str, action_value);
}
if (state->IsTerminal()) {
trajectory.returns = state->Returns();
break;
} else if (std::abs(root_value) > cutoff_value) {
} else if (std::abs(action_value) > cutoff_value) {
trajectory.returns.resize(2);
trajectory.returns[player] = root_value;
trajectory.returns[1 - player] = -root_value;
trajectory.returns[player] = action_value;
trajectory.returns[1 - player] = -action_value;
break;
}
}
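In the hunk above, the value used for the training target and the resign cutoff switches from the root's mean value to the value of the child actually chosen, expressed from the current player's perspective, with a solved outcome taking precedence when one exists. A minimal standalone sketch of that computation, using a simplified stand-in for the MCTS SearchNode (field names mirror the diff, but this is not the real class):

```cpp
#include <vector>

// Sketch only: a simplified stand-in for the SearchNode referenced in the diff.
struct Node {
  double total_reward = 0.0;    // accumulated from 'player's perspective
  int explore_count = 0;
  int player = 0;               // perspective in which rewards are stored
  std::vector<double> outcome;  // per-player solved values; empty if unsolved
};

// Value of selecting 'chosen', from the point of view of 'current_player'.
double ActionValue(const Node& chosen, int current_player) {
  if (!chosen.outcome.empty()) {
    return chosen.outcome[current_player];  // solved: use the exact value
  }
  double mean = chosen.total_reward / chosen.explore_count;
  return mean * (chosen.player == current_player ? 1 : -1);  // flip sign
}
```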
@@ -165,7 +178,8 @@ std::unique_ptr<MCTSBot> InitAZBot(const AlphaZeroConfig& config,
std::shared_ptr<Evaluator> evaluator,
bool evaluation) {
return std::make_unique<MCTSBot>(
game, std::move(evaluator), config.uct_c, config.max_simulations,
game, std::move(evaluator), config.uct_c,
config.min_simulations, config.max_simulations,
/*max_memory_mb=*/10,
/*solve=*/false,
/*seed=*/0,
@@ -269,7 +283,8 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
bots.reserve(2);
bots.push_back(InitAZBot(config, game, vp_eval, true));
bots.push_back(std::make_unique<MCTSBot>(
game, rand_evaluator, config.uct_c, rand_max_simulations,
game, rand_evaluator, config.uct_c,
/*min_simulations=*/0, rand_max_simulations,
/*max_memory_mb=*/1000,
/*solve=*/true,
/*seed=*/num * 1000 + game_num,
@@ -295,12 +310,54 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
logger.Print("Got a quit.");
}

// Returns the TD(lambda) return of 'trajectory' starting at 'state_idx': a
// lambda-weighted combination of the search values of the current and all
// subsequent states, bootstrapped on the game outcome. The calculation is
// truncated after 'td_n_steps' steps if that parameter is greater than zero.
double TdLambdaReturns(const Trajectory& trajectory, int state_idx,
double td_lambda, int td_n_steps) {
double outcome = trajectory.returns[0];
if (td_lambda >= 1.0 || Near(td_lambda, 1.0)) {
// lambda == 1.0 simplifies to returning the outcome (or value at nth-step)
if (td_n_steps <= 0) {
return outcome;
}
int idx = state_idx + td_n_steps;
if (idx >= trajectory.states.size()) {
return outcome;
}
const Trajectory::State& n_state = trajectory.states[idx];
return n_state.value * (n_state.current_player == 0 ? 1 : -1);
}
const Trajectory::State& s_state = trajectory.states[state_idx];
double retval = s_state.value * (s_state.current_player == 0 ? 1 : -1);
if (td_lambda <= 0.0 || Near(td_lambda, 0.0)) {
// lambda == 0 simplifies to returning the start state's value
return retval;
}
double lambda_inv = (1.0 - td_lambda);
double lambda_pow = td_lambda;
retval *= lambda_inv;
for (int i = state_idx + 1; i < trajectory.states.size(); ++i) {
const Trajectory::State& i_state = trajectory.states[i];
double value = i_state.value * (i_state.current_player == 0 ? 1 : -1);
if (td_n_steps > 0 && i == state_idx + td_n_steps) {
retval += lambda_pow * value;
return retval;
}
retval += lambda_inv * lambda_pow * value;
lambda_pow *= td_lambda;
}
retval += lambda_pow * outcome;
return retval;
}
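For reference, the weighting TdLambdaReturns applies can be written independently of the Trajectory type: the target is (1 - lambda) * sum_k lambda^k * v_{t+k}, plus lambda^K times the game outcome (or the n-th step value when truncated). A minimal, illustrative sketch of the same scheme over a plain vector of values that are already signed from player 0's perspective (not part of the PR):

```cpp
#include <cstdio>
#include <vector>

// Sketch only: mirrors the lambda / n-step weighting of TdLambdaReturns above,
// but over plain doubles instead of Trajectory::State.
double TdTarget(const std::vector<double>& values, int t, double outcome,
                double lambda, int n_steps) {
  double target = (1.0 - lambda) * values[t];
  double lambda_pow = lambda;
  for (int i = t + 1; i < static_cast<int>(values.size()); ++i) {
    if (n_steps > 0 && i == t + n_steps) {
      return target + lambda_pow * values[i];  // truncate at the n-th step
    }
    target += (1.0 - lambda) * lambda_pow * values[i];
    lambda_pow *= lambda;
  }
  return target + lambda_pow * outcome;  // bootstrap on the final outcome
}

int main() {
  std::vector<double> values = {0.1, 0.3, 0.6, 0.9};  // made-up search values
  std::printf("%.3f\n", TdTarget(values, 0, /*outcome=*/1.0, 0.0, 0));  // 0.100
  std::printf("%.3f\n", TdTarget(values, 0, /*outcome=*/1.0, 1.0, 0));  // 1.000
  std::printf("%.3f\n", TdTarget(values, 0, /*outcome=*/1.0, 0.5, 0));  // 0.319
}
```

With lambda = 0 this collapses to the state's own search value, and with lambda = 1 to the game outcome, matching the two special cases the function handles explicitly.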

void learner(const open_spiel::Game& game, const AlphaZeroConfig& config,
DeviceManager* device_manager,
std::shared_ptr<VPNetEvaluator> eval,
ThreadedQueue<Trajectory>* trajectory_queue,
EvalResults* eval_results, StopToken* stop,
const StartInfo& start_info) {
const StartInfo& start_info,
bool verbose = false) {
FileLogger logger(config.path, "learner", "a");
DataLoggerJsonLines data_logger(
config.path, "learner", true, "a", start_info.start_time);
@@ -357,10 +414,19 @@ void learner(const open_spiel::Game& game, const AlphaZeroConfig& config,
double p1_outcome = trajectory->returns[0];
outcomes.Add(p1_outcome > 0 ? 0 : (p1_outcome < 0 ? 1 : 2));

for (const Trajectory::State& state : trajectory->states) {
for (int i = 0; i < trajectory->states.size(); ++i ) {
const Trajectory::State& state = trajectory->states[i];
double value = TdLambdaReturns(*trajectory, i,
config.td_lambda, config.td_n_steps);
replay_buffer.Add(VPNetModel::TrainInputs{state.legal_actions,
state.observation,
state.policy, p1_outcome});
state.policy,
value});
if (verbose && num_trajectories == 1) {
double v = state.value * (state.current_player == 0 ? 1 : -1);
logger.Print("StateIdx: %d Value: %0.3f TrainTo: %0.3f",
i, v, value);
}
num_states += 1;
}
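The replay targets above now come from TdLambdaReturns rather than always being the final outcome p1_outcome. With td_lambda = 1.0 and td_n_steps = 0 the function simply returns trajectory.returns[0], so those settings reproduce the previous outcome-only target. A hypothetical helper, not part of the PR, illustrating the new knobs (it assumes the torch_az namespace for the header shown in this diff):

```cpp
#include "open_spiel/algorithms/alpha_zero_torch/alpha_zero.h"

namespace open_spiel::algorithms::torch_az {

void ConfigureTdTargets(AlphaZeroConfig* config) {
  // td_lambda = 1.0 with td_n_steps = 0 makes TdLambdaReturns return
  // trajectory.returns[0], i.e. the previous outcome-only training target.
  config->td_lambda = 1.0;
  config->td_n_steps = 0;
  // td_lambda = 0.0 would instead train each state toward its own MCTS value;
  // intermediate values blend the two, optionally truncated after n steps.
}

}  // namespace open_spiel::algorithms::torch_az
```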

75 changes: 44 additions & 31 deletions open_spiel/algorithms/alpha_zero_torch/alpha_zero.h
@@ -49,13 +49,16 @@ struct AlphaZeroConfig {
int evaluation_window;

double uct_c;
int min_simulations;
int max_simulations;
double policy_alpha;
double policy_epsilon;
double temperature;
double temperature_drop;
double cutoff_probability;
double cutoff_value;
double td_lambda;
int td_n_steps;

int actors;
int evaluators;
@@ -83,51 +86,61 @@ struct AlphaZeroConfig {
{"checkpoint_freq", checkpoint_freq},
{"evaluation_window", evaluation_window},
{"uct_c", uct_c},
{"min_simulations", min_simulations},
{"max_simulations", max_simulations},
{"policy_alpha", policy_alpha},
{"policy_epsilon", policy_epsilon},
{"temperature", temperature},
{"temperature_drop", temperature_drop},
{"cutoff_probability", cutoff_probability},
{"cutoff_value", cutoff_value},
{"td_lambda", td_lambda},
{"td_n_steps", td_n_steps},
{"actors", actors},
{"evaluators", evaluators},
{"eval_levels", eval_levels},
{"max_steps", max_steps},
});
}

void FromJson(const json::Object& config_json) {
game = config_json.at("game").GetString();
path = config_json.at("path").GetString();
graph_def = config_json.at("graph_def").GetString();
nn_model = config_json.at("nn_model").GetString();
nn_width = config_json.at("nn_width").GetInt();
nn_depth = config_json.at("nn_depth").GetInt();
devices = config_json.at("devices").GetString();
explicit_learning = config_json.at("explicit_learning").GetBool();
learning_rate = config_json.at("learning_rate").GetDouble();
weight_decay = config_json.at("weight_decay").GetDouble();
train_batch_size = config_json.at("train_batch_size").GetInt();
inference_batch_size = config_json.at("inference_batch_size").GetInt();
inference_threads = config_json.at("inference_threads").GetInt();
inference_cache = config_json.at("inference_cache").GetInt();
replay_buffer_size = config_json.at("replay_buffer_size").GetInt();
replay_buffer_reuse = config_json.at("replay_buffer_reuse").GetInt();
checkpoint_freq = config_json.at("checkpoint_freq").GetInt();
evaluation_window = config_json.at("evaluation_window").GetInt();
uct_c = config_json.at("uct_c").GetDouble();
max_simulations = config_json.at("max_simulations").GetInt();
policy_alpha = config_json.at("policy_alpha").GetDouble();
policy_epsilon = config_json.at("policy_epsilon").GetDouble();
temperature = config_json.at("temperature").GetDouble();
temperature_drop = config_json.at("temperature_drop").GetDouble();
cutoff_probability = config_json.at("cutoff_probability").GetDouble();
cutoff_value = config_json.at("cutoff_value").GetDouble();
actors = config_json.at("actors").GetInt();
evaluators = config_json.at("evaluators").GetInt();
eval_levels = config_json.at("eval_levels").GetInt();
max_steps = config_json.at("max_steps").GetInt();
void FromJsonWithDefaults(const json::Object& config_json,
const json::Object& defaults_json) {
json::Object merged;
merged.insert(config_json.begin(), config_json.end());
merged.insert(defaults_json.begin(), defaults_json.end());
game = merged.at("game").GetString();
path = merged.at("path").GetString();
graph_def = merged.at("graph_def").GetString();
nn_model = merged.at("nn_model").GetString();
nn_width = merged.at("nn_width").GetInt();
nn_depth = merged.at("nn_depth").GetInt();
devices = merged.at("devices").GetString();
explicit_learning = merged.at("explicit_learning").GetBool();
learning_rate = merged.at("learning_rate").GetDouble();
weight_decay = merged.at("weight_decay").GetDouble();
train_batch_size = merged.at("train_batch_size").GetInt();
inference_batch_size = merged.at("inference_batch_size").GetInt();
inference_threads = merged.at("inference_threads").GetInt();
inference_cache = merged.at("inference_cache").GetInt();
replay_buffer_size = merged.at("replay_buffer_size").GetInt();
replay_buffer_reuse = merged.at("replay_buffer_reuse").GetInt();
checkpoint_freq = merged.at("checkpoint_freq").GetInt();
evaluation_window = merged.at("evaluation_window").GetInt();
uct_c = merged.at("uct_c").GetDouble();
min_simulations = merged.at("min_simulations").GetInt();
max_simulations = merged.at("max_simulations").GetInt();
policy_alpha = merged.at("policy_alpha").GetDouble();
policy_epsilon = merged.at("policy_epsilon").GetDouble();
temperature = merged.at("temperature").GetDouble();
temperature_drop = merged.at("temperature_drop").GetDouble();
cutoff_probability = merged.at("cutoff_probability").GetDouble();
cutoff_value = merged.at("cutoff_value").GetDouble();
td_lambda = merged.at("td_lambda").GetDouble();
td_n_steps = merged.at("td_n_steps").GetInt();
actors = merged.at("actors").GetInt();
evaluators = merged.at("evaluators").GetInt();
eval_levels = merged.at("eval_levels").GetInt();
max_steps = merged.at("max_steps").GetInt();
}
};
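FromJsonWithDefaults merges the two JSON objects before reading any field. Because map-style insert ignores keys that already exist, inserting config_json first means explicit config values win, and defaults_json only fills in keys missing from older checkpoints (such as the new min_simulations, td_lambda and td_n_steps). A minimal sketch of that precedence rule with a plain std::map rather than OpenSpiel's json::Object:

```cpp
#include <cassert>
#include <map>
#include <string>

int main() {
  std::map<std::string, double> config   = {{"uct_c", 1.5}};
  std::map<std::string, double> defaults = {{"uct_c", 2.0}, {"td_lambda", 1.0}};

  std::map<std::string, double> merged;
  merged.insert(config.begin(), config.end());      // inserted first: wins
  merged.insert(defaults.begin(), defaults.end());  // only fills missing keys

  assert(merged.at("uct_c") == 1.5);      // explicit config value kept
  assert(merged.at("td_lambda") == 1.0);  // default supplied the new field
  return 0;
}
```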
