This module provides an alternative training facility, which is further used for large-scale distributed training and fine-tuning.
(TL;DR: in case you skipped Quick Start, you can clone this Google Colab file to start using this trainer module.)
Using the gym environment through RL implementations like stable-baselines3 is a great way to train RL models.
Attempts were made, but the code is not published along with the repo (it is very dirty and messy, but may be available upon request). Still, there are a few points that make it not a horrible idea to recreate a custom training facility:
- Our `gym` environment (`gym_env`) uses `ListPoints` and runs one game at a time. But the states can be recorded as a `Tensor` and operations can be vectorized, which allows us to run multiple games on the GPU. This might also be doable in `gym`, but I foresee many potentially messy wrappers to reformat data.
- `stable-baselines3` is a very professionally written project with an amazing execution of OOP principles. But everything comes with a trade-off. For example, when tracking and customizing a small segment of the procedure, one has to go through multiple subclasses from different base classes scattered across the folders. Modularity is certainly a great engineering concept, but a fused and streamlined structure may have a role to play when horizontal scaling is relatively unimportant.
- We are running our code on a massive GPU cluster consisting of thousands of nodes, each of which has 8 Nvidia A100s. Distributed training can mean many things in this context (distributed hyperparameter search, distributed training of a single network, distributed model selection like genetic algorithms, etc.), which will eventually require a lot of highly customized code.
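To illustrate the vectorization point above, here is a toy sketch (plain NumPy, unrelated to the repo's actual code) of why batching states as arrays helps: one batched operation advances many games at once, instead of stepping each environment in a Python loop.

```python
import numpy as np

# Toy illustration only; shapes and the "update" are made up.
num_games, state_dim = 4, 3
states = np.zeros((num_games, state_dim))   # one row per game
actions = np.array([1.0, 2.0, 3.0, 4.0])    # one action per game

# A single vectorized update touches all games at once.
states += actions[:, None]
print(states.shape)  # → (4, 3)
```

The same idea carries over to `torch.Tensor` on the GPU, which is what the fused gameplay classes below exploit.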
Hence, I made this submodule.
Note that it is not my goal to make a general-purpose RL trainer. I separate host- and agent-related attributes instead of calling them by index (e.g. `self.player_net[0]`, `self.player_net[1]`) due to the asymmetric nature of the game.
We may want to do different things for host and agent specifically.
But it should be really easy to adapt this code base to different symmetric games.
The classes center around `Trainer`. Every trainer must be initialized with an exhaustive configuration dict, without hidden parameters defaulting to certain values. This way, users are always reminded about everything that goes into training.
An example config file in YAML format is given for each `Trainer` subclass.
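As an illustration of such an exhaustive config, a hypothetical YAML file might look like the following (every key name here is an assumption for illustration, not the repo's actual schema):

```yaml
# Hypothetical config sketch; key names are illustrative only.
host_net:
  hidden_sizes: [256, 256]
agent_net:
  hidden_sizes: [256, 256]
optimizer:
  lr: 0.001
replay_buffer:
  capacity: 100000
training:
  batch_size: 64
  total_steps: 1000000
  exploration_rate_start: 1.0
  exploration_rate_end: 0.05
```

The point is that nothing is implicit: a missing key should fail loudly at initialization rather than silently fall back to a default.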
`Trainer` provides the facility of our training. To implement a specific way of training, please:

- create a class that inherits `Trainer`,
- read the docstring on what must be implemented and what can be overridden,
- implement the training logic in `_train()` based on what the class provides.
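The steps above can be sketched as follows. The base class below is a toy stand-in, not the repo's actual `Trainer` (which carries far more machinery); everything besides the names `Trainer` and `_train()` is an assumption for illustration.

```python
class Trainer:
    """Toy stand-in for the real Trainer base class."""

    def __init__(self, config: dict):
        # Exhaustive config: every key must be present, no hidden defaults.
        required = {"lr", "batch_size", "total_steps"}
        missing = required - config.keys()
        if missing:
            raise KeyError(f"config is missing required keys: {missing}")
        self.config = config

    def train(self, steps: int):
        # Shared bookkeeping lives in the base class; the algorithm-specific
        # logic is delegated to the subclass's _train().
        for step in range(steps):
            self._train(step)

    def _train(self, step: int):
        raise NotImplementedError


class ToyTrainer(Trainer):
    """A subclass holding only the algorithm-specific training logic."""

    def __init__(self, config: dict):
        super().__init__(config)
        self.losses = []

    def _train(self, step: int):
        # Real code would roll out games and take a gradient step here;
        # we just record a dummy loss to show the hook being called.
        self.losses.append(1.0 / (step + 1))


trainer = ToyTrainer({"lr": 1e-3, "batch_size": 32, "total_steps": 3})
trainer.train(3)
print(len(trainer.losses))  # → 3
```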
We hope that a subclass implementation is a highly concentrated piece of code that concerns only the training logic of the specific RL algorithm. This way, reading and tweaking all happen in one single place, without burying one's head under piles of folders and trees of inheritance.
What the class provides includes, but is not limited to:

- Key objects:
  - `Trainer.host_net`, `Trainer.agent_net`: host and agent networks.
  - `Trainer.host_optimizer`, `Trainer.agent_optimizer`: host and agent optimizers.
  - `Trainer.host_lr_scheduler`, `Trainer.agent_lr_scheduler`: host and agent learning-rate schedulers.
  - `Trainer.host_er_scheduler`, `Trainer.agent_er_scheduler`: host and agent exploration-rate schedulers.
  - `Trainer.host_replay_buffer`, `Trainer.agent_replay_buffer`: host and agent replay buffers.
  - `Trainer.fused_game`: a `FusedGame` object for roll-outs, based on the host and agent networks.
- Key methods:
  - `Trainer.set_training()`: set training mode to `True` or `False`. It impacts layers in networks like `BatchNorm1d`.
  - `Trainer._roll_out()`: create roll-outs in the form of experiences (obs, actions, rewards, dones, next_obs).
Performs a classic DDQN.
`FusedGame` is a class that fuses together everything about the gameplay. It avoids `gym` and our `Agent`, `Host`, `Game`, etc. All the data is processed as `torch.Tensor`.
`ReplayBuffer` is a very simple implementation of a replay buffer.
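For readers unfamiliar with the idea, here is a minimal replay-buffer sketch (an assumed interface, not the repo's actual `ReplayBuffer`), storing experiences in the roll-out format above, (obs, action, reward, done, next_obs):

```python
import random
from collections import deque


class SimpleReplayBuffer:
    """Minimal sketch: a bounded FIFO store with uniform random sampling."""

    def __init__(self, capacity: int):
        # deque with maxlen evicts the oldest experience automatically.
        self.buffer = deque(maxlen=capacity)

    def push(self, obs, action, reward, done, next_obs):
        self.buffer.append((obs, action, reward, done, next_obs))

    def sample(self, batch_size: int):
        # Uniform sampling without replacement within one batch.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


buf = SimpleReplayBuffer(capacity=2)
buf.push([0.0], 1, 0.5, False, [1.0])
buf.push([1.0], 0, 1.0, True, [2.0])
buf.push([2.0], 1, 0.0, False, [3.0])  # evicts the oldest experience
print(len(buf))  # → 2
```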
`Scheduler` is basically a function that takes the number of steps as input. The only extra is that it persists custom data.