Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixture of Experts MCTS (MoE MCTS) #216

Merged
merged 57 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
f066178
- added game phase detection file
HelpstoneX Jun 22, 2023
e7b9f00
- changed openspiel git
HelpstoneX Jun 22, 2023
ebfa472
- changed openspiel git
HelpstoneX Jun 24, 2023
ae18d66
fixed phase ids
HelpstoneX Jun 24, 2023
8eb237f
added dataset creation option for specific phases
HelpstoneX Jun 24, 2023
738cd8d
Merge branch 'dataset_creation' of https://github.com/HelpstoneX/Craz…
HelpstoneX Jun 29, 2023
3b56cee
param changes in train_cnn.ipynb
HelpstoneX Jul 3, 2023
826f3b6
- fixed plys_to_end list to only include values for moves that really…
HelpstoneX Jul 4, 2023
b92a8a8
Merge branch 'dataset_creation'
HelpstoneX Jul 5, 2023
cdbe3ab
- changes to train_cnn to make it compatible
HelpstoneX Aug 3, 2023
ce89f69
updated fork
HelpstoneX Aug 10, 2023
e55c318
mcts phase integration working, some improvements missing
HelpstoneX Sep 14, 2023
a2a9a2c
- added phase_to_nets map to make sure the right net is used for each…
HelpstoneX Sep 20, 2023
0be67ff
- added game phase vector to created datasets
HelpstoneX Sep 27, 2023
9e9b009
minor fixes for weighted training
HelpstoneX Sep 27, 2023
1700818
- fixes and improvements to prs.py from cutechess-cli
HelpstoneX Sep 28, 2023
cb575ce
- changes for continuing training from tar file (pytorch)
HelpstoneX Oct 2, 2023
964a982
- added python file for training (exported notebook)
HelpstoneX Oct 2, 2023
c3a2a25
- added python file for executing cutechess shell commands
HelpstoneX Oct 3, 2023
d05228d
- added the option to specify additional eval sets (unweighted) to pa…
HelpstoneX Oct 7, 2023
87456f2
- minor changes
HelpstoneX Oct 7, 2023
4229a04
- minor changes for debugging
HelpstoneX Oct 7, 2023
e9bed77
- bugfix in train_cnn.py for additional dataloaders
HelpstoneX Oct 7, 2023
c13d0be
- bugfix in to correctly determine train iterations
HelpstoneX Oct 9, 2023
4a11bd8
- minor changese in prs.py
HelpstoneX Oct 14, 2023
614a4c9
- minor changes for chess 960
HelpstoneX Oct 14, 2023
5049a1e
- reverted mode and version back to 2 and 3
HelpstoneX Oct 14, 2023
d0cc5f4
fixed bug when executing isready multiple times consecutively while s…
HelpstoneX Oct 26, 2023
b47f44b
alternative bugfix attempt for linux
HelpstoneX Oct 27, 2023
57e9d7b
- temporary fix for chess960 wrong training representation
HelpstoneX Nov 1, 2023
982c650
- changes to incorporate 960 dataset analysis
HelpstoneX Dec 9, 2023
d18251b
chess960 input representation fix (c++ engine files still unadjusted …
HelpstoneX Dec 9, 2023
26cb920
- added plot generating notebooks to git (/etc folder)
HelpstoneX Dec 15, 2023
5e0df73
- added support for naive movecount phases
HelpstoneX Dec 17, 2023
f065280
- minor path fix in dataset_loader.py
HelpstoneX Dec 17, 2023
1a7ab6c
undone temporary fix for broken chess960 input representation
HelpstoneX Dec 19, 2023
69e002c
- added support for phases by movecount in c++ code (currently always…
HelpstoneX Jan 12, 2024
85167cd
- minor plotting adjustments
HelpstoneX Jan 17, 2024
2059aac
- adjusted run_cutechess_experiments.py to be able to do experiments …
HelpstoneX Mar 20, 2024
bf57009
- added documentation
HelpstoneX Apr 11, 2024
09c6810
- minor assertion change in train_cnn.py
HelpstoneX Apr 11, 2024
9d5fa1f
- cleaned code and removed sections that are not needed anymore
HelpstoneX Apr 20, 2024
336f785
- changed underscore naming to camelCase naming in several cases
HelpstoneX Apr 20, 2024
fa3dfe6
- added UCI option Game_Phase_Definition with options "lichess" and "…
HelpstoneX Apr 20, 2024
9c4d507
- added searchSettings to RawNetAgent to access selected gamePhaseDef…
HelpstoneX Apr 22, 2024
687ef8a
- aligned train_cnn.ipynb with code inside train_cnn.py
HelpstoneX Apr 22, 2024
65a302b
- cleaned cell outputs of main notebooks
HelpstoneX May 1, 2024
2fae83a
- further notebook output cleanings
HelpstoneX May 1, 2024
2917f78
- removed files unnecessary for pull request and reverted several fil…
HelpstoneX May 1, 2024
df97fac
- reverted .gitignore and Dockerfile to older state
HelpstoneX May 1, 2024
2fd46bb
- .gitignore update to different previous state
HelpstoneX May 1, 2024
6df0f5a
Merge branch 'master' into pull_request_preparation
QueensGambit May 2, 2024
4030367
Update crazyara.cpp
QueensGambit May 2, 2024
81e1a94
Update board.cpp
QueensGambit May 2, 2024
3cbb403
Add GamePhase get_phase to states
QueensGambit May 3, 2024
fabc465
Add GamePhase OpenSpielState::get_phase()
QueensGambit May 3, 2024
a4853f5
Update get_data_loader() to load dict instead
QueensGambit May 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
- added UCI option Game_Phase_Definition with options "lichess" and "…
…movecount" and corresponding searchsettings enum GamePhaseDefinition
  • Loading branch information
HelpstoneX committed Apr 20, 2024
commit fa3dfe628206cb73bda13e1385bf4e4cb379ec07
3 changes: 2 additions & 1 deletion engine/src/agents/config/searchsettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ SearchSettings::SearchSettings():
searchPlayerMode(MODE_TWO_PLAYER),
virtualStyle(VIRTUAL_VISIT),
virtualMixThreshold(1000),
virtualOffsetStrenght(0.001)
virtualOffsetStrenght(0.001),
gamePhaseDefinition(MOVECOUNT)
{

}
7 changes: 7 additions & 0 deletions engine/src/agents/config/searchsettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ enum VirtualStyle {
VIRTUAL_MIX
};

enum GamePhaseDefinition {
LICHESS,
MOVECOUNT
};

struct SearchSettings
{
uint16_t multiPV;
Expand Down Expand Up @@ -87,6 +92,8 @@ struct SearchSettings
uint_fast32_t virtualMixThreshold;
// Defines the strength of the virtual offset
double virtualOffsetStrenght;
// Defines the type of game phase definition to be used
GamePhaseDefinition gamePhaseDefinition;
SearchSettings();

};
Expand Down
2 changes: 1 addition & 1 deletion engine/src/agents/mctsagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ shared_ptr<Node> MCTSAgent::get_root_node_from_tree(StateObj *state)
void MCTSAgent::set_root_node_predictions()
{
state->get_state_planes(true, inputPlanes, nets.front()->get_version());
GamePhase currentPhase = state->get_phase(numPhases);
GamePhase currentPhase = state->get_phase(numPhases, searchSettings->gamePhaseDefinition);
nets[phaseToNetsIndex.at(currentPhase)]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
size_t tbHits = 0;
fill_nn_results(0, nets[phaseToNetsIndex.at(currentPhase)]->is_policy_map(), valueOutputs, probOutputs, auxiliaryOutputs, rootNode.get(), tbHits,
Expand Down
3 changes: 2 additions & 1 deletion engine/src/agents/rawnetagent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ void RawNetAgent::evaluate_board_state()
return;
}
state->get_state_planes(true, inputPlanes, nets.front()->get_version());
nets[phaseToNetsIndex.at(state->get_phase(numPhases))]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
// TODO: currently always uses MOVECOUNT as GamePhaseDefinition because RawNetAgent has no SearchSettings available
nets[phaseToNetsIndex.at(state->get_phase(numPhases, MOVECOUNT))]->predict(inputPlanes, valueOutputs, probOutputs, auxiliaryOutputs);
state->set_auxiliary_outputs(auxiliaryOutputs);

evalInfo->policyProbSmall.resize(evalInfo->legalMoves.size());
Expand Down
34 changes: 19 additions & 15 deletions engine/src/environments/chess_related/board.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,12 @@ int get_mixedness(const Board& pos)
return mix;
}

GamePhase Board::get_phase(unsigned int numPhases) const
GamePhase Board::get_phase(unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition) const
{
if (numPhases == 3 && true) {
// currently enabled so that trivial phases by move count are not used if numPhases == 3
if (gamePhaseDefinition == LICHESS) {

assert(numPhases == 3); // lichess definition requires three models to be loaded

// returns the game phase based on the lichess definition implemented in:
// https://github.com/lichess-org/scalachess/blob/master/src/main/scala/Divider.scala
unsigned int numMajorsAndMinors = get_majors_and_minors_count(*this);
Expand All @@ -564,19 +566,21 @@ GamePhase Board::get_phase(unsigned int numPhases) const
}
}
}
if (numPhases == 1) {
return GamePhase(0);
}
else { // use naive phases by move count
double averageMovecountPerGame = 42.85;
double phaseLength = std::round(averageMovecountPerGame / numPhases);
size_t movesCompleted = this->total_move_cout();
double gamePhaseDouble = movesCompleted / phaseLength;
if (gamePhaseDouble > numPhases - 1){ // ensure that all higher results are attributed to the last phase
return GamePhase(numPhases - 1);
else if (gamePhaseDefinition == MOVECOUNT) {
if (numPhases == 1) { // directly return phase 0 if there is only a single network loaded
return GamePhase(0);
}
else {
return GamePhase(gamePhaseDouble); // truncated to Integer value
else { // use naive phases by move count
double averageMovecountPerGame = 42.85;
double phaseLength = std::round(averageMovecountPerGame / numPhases);
size_t movesCompleted = this->total_move_cout();
double gamePhaseDouble = movesCompleted / phaseLength;
if (gamePhaseDouble > numPhases - 1) { // ensure that all higher results are attributed to the last phase
return GamePhase(numPhases - 1);
}
else {
return GamePhase(gamePhaseDouble); // truncated to Integer value
}
}
}
}
5 changes: 3 additions & 2 deletions engine/src/environments/chess_related/board.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ class Board : public Position
bool draw_by_insufficient_material() const;

/**
* @brief get_phase Returns the game phase of the current board state based on the total amount of phases
* @brief get_phase Returns the game phase of the current board state based on the total amount of phases and the chosen GamePhaseDefinition
* Possible returned values are all integers from 0 to numPhases - 1
* @param unsigned int numPhases
* @param GamePhaseDefinition gamePhaseDefinition
* @return Game phase as unsigned int
*/
GamePhase get_phase(unsigned int numPhases) const;
GamePhase get_phase(unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition) const;

// overloaded function which include a last move list update
void do_move(Move m, StateInfo& newSt);
Expand Down
4 changes: 2 additions & 2 deletions engine/src/environments/chess_related/boardstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ void BoardState::init(int variant, bool is960)

#endif

GamePhase BoardState::get_phase(unsigned int numPhases) const
GamePhase BoardState::get_phase(unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition) const
{
return board.get_phase(numPhases);
return board.get_phase(numPhases, gamePhaseDefinition);
}
2 changes: 1 addition & 1 deletion engine/src/environments/chess_related/boardstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class BoardState : public State
void set_auxiliary_outputs(const float* auxiliaryOutputs) override;
BoardState* clone() const override;
void init(int variant, bool isChess960) override;
GamePhase get_phase(unsigned int numPhases) const;
GamePhase get_phase(unsigned int numPhases, GamePhaseDefinition gamePhaseDefinition) const;
};

#endif // BOARTSTATE_H
Expand Down
2 changes: 1 addition & 1 deletion engine/src/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ Node* SearchThread::get_new_child_to_evaluate(NodeDescription& description)
// fill a new board in the input_planes vector
// we shift the index by nbNNInputValues each time
newState->get_state_planes(true, inputPlanes + newNodes->size() * nets.front()->get_nb_input_values_total(), nets.front()->get_version());
GamePhase currPhase = newState->get_phase(numPhases);
GamePhase currPhase = newState->get_phase(numPhases, searchSettings->gamePhaseDefinition);
phaseCountMap[currPhase]++;
// save a reference newly created list in the temporary list for node creation
// it will later be updated with the evaluation of the NN
Expand Down
1 change: 1 addition & 0 deletions engine/src/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <memory>
#include "version.h"
#include "util/communication.h"
#include "agents/config/searchsettings.h"

typedef uint64_t Key;
#ifdef ACTION_64_BIT
Expand Down
9 changes: 9 additions & 0 deletions engine/src/uci/crazyara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,15 @@ void CrazyAra::init_search_settings()
info_string_important("Unknown option", Options["Virtual_Style"], "for Virtual_Style");
}
searchSettings.virtualMixThreshold = Options["Virtual_Mix_Threshold"];
if (Options["Game_Phase_Definition"] == "lichess") {
searchSettings.gamePhaseDefinition = LICHESS;
}
else if (Options["Game_Phase_Definition"] == "movecount") {
searchSettings.gamePhaseDefinition = MOVECOUNT;
}
else {
info_string_important("Unknown option", Options["Game_Phase_Definition"], "for Game_Phase_Definition");
}
}

void CrazyAra::init_play_settings()
Expand Down
1 change: 1 addition & 0 deletions engine/src/uci/optionsuci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ void OptionsUCI::init(OptionsMap &o)
o["Use_Raw_Network"] << Option(false);
o["Virtual_Style"] << Option("virtual_mix", { "virtual_loss", "virtual_visit", "virtual_offset", "virtual_mix" });
o["Virtual_Mix_Threshold"] << Option(1000, 1, 99999999);
o["Game_Phase_Definition"] << Option("movecount", { "lichess", "movecount"});
// additional UCI-Options for RL only
#ifdef USE_RL
o["Centi_Node_Random_Factor"] << Option(10, 0, 100);
Expand Down