Showing 3 changed files with 333 additions and 0 deletions.
@@ -0,0 +1,90 @@
import sys
sys.path.append("/home/jieyi/rl4co")

import pytz
import torch

from datetime import datetime
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.meta_trainer import RL4COMetaTrainer, MetaModelCallback


def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={'num_loc': 50})

    # Policy: neural network, in this case with an encoder-decoder architecture
    # Note: these settings follow the ones used by POMO in the original paper
    policy = AttentionModelPolicy(
        env_name=env.name,
        embed_dim=128,
        num_encoder_layers=6,
        num_heads=8,
        normalization="instance",
        use_graph_context=False,
    )

    # RL Model (POMO)
    model = POMO(
        env,
        policy,
        batch_size=64,  # meta_batch_size
        train_data_size=64 * 50,  # each epoch
        val_data_size=0,
        optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
        # for the task scheduler of the size setting, where sch_epoch = 0.9 * epochs
    )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # also save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback

    # Meta callbacks
    meta_callback = MetaModelCallback(
        meta_params={
            'meta_method': 'reptile',  # choose from ['maml', 'fomaml', 'maml_fomaml', 'reptile']
            'data_type': 'size',  # choose from ["size", "distribution", "size_distribution"]
            'sch_bar': 0.9,  # for the task scheduler of the size setting, where sch_epoch = sch_bar * epochs
            'B': 1,  # the number of tasks in a mini-batch
            'alpha': 0.99,  # outer-loop step size of Reptile
            'alpha_decay': 0.999,  # decay of the outer-loop step size of Reptile
            'min_size': 20,  # minimum sampled size in meta tasks
            'max_size': 150,  # maximum sampled size in meta tasks
        },
        print_log=True,  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    process_start_time = datetime.now(pytz.timezone("Asia/Singapore"))
    logger = WandbLogger(project="rl4co", name=f"{env.name}_{process_start_time.strftime('%Y%m%d_%H%M%S')}")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COMetaTrainer(
        max_epochs=20000,  # (the number of meta-model updates) * (the number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50,  # gradient descent steps in the inner-loop optimization of the meta-learning method
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()
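
The comments above pin down the training budget: with B = 1 task per meta mini-batch, max_epochs = 20000 corresponds to 20,000 meta-model updates, each following limit_train_batches = 50 inner-loop gradient steps on batches of 64 instances. A quick back-of-the-envelope sketch of that schedule (illustrative only; the actual bookkeeping is done by Lightning and MetaModelCallback):

    # Rough schedule implied by the settings above (illustrative only)
    max_epochs = 20_000        # Lightning epochs
    B = 1                      # tasks per meta mini-batch
    limit_train_batches = 50   # inner-loop gradient steps per epoch
    batch_size = 64            # instances per inner-loop step

    meta_updates = max_epochs // B                   # 20,000 meta-model (Reptile) updates
    inner_steps = max_epochs * limit_train_batches   # 1,000,000 inner-loop gradient steps
    instances_seen = inner_steps * batch_size        # 64,000,000 sampled CVRP instances
    print(meta_updates, inner_steps, instances_seen)

Note that train_data_size = 64 * 50 matches batch_size * limit_train_batches, so each epoch consumes exactly one freshly generated dataset.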
@@ -0,0 +1,242 @@
from typing import Iterable, List, Optional, Union

import copy
import math
import random

import lightning.pytorch as pl
import torch
from torch.optim import Adam

from lightning import Callback, Trainer
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies import DDPStrategy, Strategy
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

from rl4co import utils

log = utils.get_pylogger(__name__)


class MetaModelCallback(Callback):
    def __init__(self, meta_params, print_log=True):
        super().__init__()
        self.meta_params = meta_params
        # Only the Reptile method on the "size" task setting is supported for now
        assert meta_params["meta_method"] == 'reptile', NotImplementedError
        assert meta_params["data_type"] == 'size', NotImplementedError
        self.print_log = print_log

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Initialize some hyperparameters
        self.alpha = self.meta_params["alpha"]
        self.alpha_decay = self.meta_params["alpha_decay"]
        self.sch_bar = self.meta_params["sch_bar"]
        self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)]

        # Sample a batch of tasks
        self._sample_task()
        self.selected_tasks[0] = (pl_module.env.generator.num_loc,)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Alpha scheduler (decay for the update of the meta model)
        self._alpha_scheduler()

        # Reinitialize the task model with the parameters of the meta model
        if trainer.current_epoch % self.meta_params['B'] == 0:  # Save the meta model
            self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
            self.task_models = []
            # Print sampled tasks
            if self.print_log:
                print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(
                    trainer.current_epoch // self.meta_params['B'], trainer.current_epoch, self.selected_tasks))
        else:
            pl_module.load_state_dict(self.meta_model_state_dict)

        # Reinitialize the optimizer every epoch; the learning rate is decayed by 10x
        # once training reaches sch_bar * max_epochs epochs
        lr_decay = 0.1 if trainer.current_epoch + 1 == int(self.sch_bar * trainer.max_epochs) else 1
        old_lr = trainer.optimizers[0].param_groups[0]['lr']
        new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
        trainer.optimizers = [new_optimizer]

        if self.print_log:
            print('\n>> Training task: {}, capacity: {}'.format(
                pl_module.env.generator.num_loc, pl_module.env.generator.capacity))

    def on_train_epoch_end(self, trainer, pl_module):
        # Save the task model
        self.task_models.append(copy.deepcopy(pl_module.state_dict()))
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Outer-loop optimization (update the meta model with the parameters of the task models)
            with torch.no_grad():
                state_dict = {params_key: (self.meta_model_state_dict[params_key] +
                                           self.alpha * torch.mean(torch.stack(
                                               [fast_weight[params_key] - self.meta_model_state_dict[params_key]
                                                for fast_weight in self.task_models], dim=0).float(), dim=0))
                              for params_key in self.meta_model_state_dict}
                pl_module.load_state_dict(state_dict)

        # Get ready for the next meta-training iteration
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Sample a batch of tasks
            self._sample_task()

        # Load the new training task (update the environment)
        self._load_task(pl_module, task_idx=(trainer.current_epoch + 1) % self.meta_params['B'])

    def _sample_task(self):
        # Sample a batch of tasks (uniform task weights, so the softmax below is uniform)
        w, self.selected_tasks = [1.0] * self.meta_params['B'], []
        for b in range(self.meta_params['B']):
            task_params = random.sample(self.task_set, 1)[0]
            self.selected_tasks.append(task_params)
        self.w = torch.softmax(torch.Tensor(w), dim=0)

    def _load_task(self, pl_module, task_idx=0):
        # Load the new training task (update the environment)
        task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
        # CVRP capacity heuristic, e.g. num_loc = 50 -> capacity = ceil(30 + 50 / 5) = 40
        task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
        pl_module.env.generator.num_loc = task_params[0]
        pl_module.env.generator.capacity = task_capacity

    def _alpha_scheduler(self):
        # Decay the outer-loop step size, with a floor of 1e-4
        self.alpha = max(self.alpha * self.alpha_decay, 0.0001)


class RL4COMetaTrainer(Trainer):
    """Wrapper around the Lightning Trainer, with some RL4CO magic for efficient training.

    Note:
        The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`.
        This allows datasets to be re-created on the fly and distributed by Lightning across
        devices on each epoch. Setting it to a value different than 1 may lead to overfitting to a
        specific (such as the initial) data distribution.

    Args:
        accelerator: hardware accelerator to use.
        callbacks: list of callbacks.
        logger: logger (or iterable collection of loggers) for experiment tracking.
        min_epochs: minimum number of training epochs.
        max_epochs: maximum number of training epochs.
        strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP).
        devices: number of devices to train on (int) or which GPUs to train on (list or str), applied per node.
        gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability.
        precision: allows for mixed-precision training. Can be specified as a string (e.g., '16').
            This also allows `FlashAttention` to be used by default.
        disable_profiling_executor: Disable the JIT profiling executor. This reduces memory and increases speed.
        auto_configure_ddp: Automatically configure the DDP strategy if multiple GPUs are available.
        reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs.
        matmul_precision: Set matmul precision for faster inference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details.
    """

    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[Union[Logger, Iterable[Logger]]] = None,
        min_epochs: Optional[int] = None,
        max_epochs: Optional[int] = None,
        strategy: Union[str, Strategy] = "auto",
        devices: Union[List[int], str, int] = "auto",
        gradient_clip_val: Union[int, float] = 1.0,
        precision: Union[str, int] = "16-mixed",
        reload_dataloaders_every_n_epochs: int = 1,
        disable_profiling_executor: bool = True,
        auto_configure_ddp: bool = True,
        matmul_precision: Union[str, int] = "medium",
        **kwargs,
    ):
        # Disable the JIT profiling executor. This reduces memory and increases speed.
        # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17
        if disable_profiling_executor:
            try:
                torch._C._jit_set_profiling_executor(False)
                torch._C._jit_set_profiling_mode(False)
            except AttributeError:
                pass

        # Configure DDP automatically if multiple GPUs are available
        if auto_configure_ddp and strategy == "auto":
            if devices == "auto":
                n_devices = num_cuda_devices()
            elif isinstance(devices, Iterable):
                n_devices = len(devices)
            else:
                n_devices = devices
            if n_devices > 1:
                log.info(
                    "Configuring DDP strategy automatically with {} GPUs".format(
                        n_devices
                    )
                )
                strategy = DDPStrategy(
                    find_unused_parameters=True,  # We set this to True due to RL envs
                    gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
                )

        # Set matmul precision for faster inference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        if matmul_precision is not None:
            torch.set_float32_matmul_precision(matmul_precision)

        # Check if gradient_clip_val is set to None
        if gradient_clip_val is None:
            log.warning(
                "gradient_clip_val is set to None. This may lead to unstable training."
            )

        # We should reload dataloaders every epoch for RL training
        if reload_dataloaders_every_n_epochs != 1:
            log.warning(
                "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 "
                + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs."
            )

        # Main call to the `Trainer` superclass
        super().__init__(
            accelerator=accelerator,
            callbacks=callbacks,
            logger=logger,
            min_epochs=min_epochs,
            max_epochs=max_epochs,
            strategy=strategy,
            gradient_clip_val=gradient_clip_val,
            devices=devices,
            precision=precision,
            reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
            **kwargs,
        )

    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        """
        We override the `fit` method to automatically apply and handle RL4CO magic
        for `self.automatic_optimization = False` models, such as PPO.
        It behaves exactly like the original `fit` method, but with the following change:
        - if the given model uses `self.automatic_optimization = False`, we override `gradient_clip_val` with None
        """
        if not model.automatic_optimization:
            if self.gradient_clip_val is not None:
                log.warning(
                    "Overriding gradient_clip_val to None for 'automatic_optimization=False' models"
                )
                self.gradient_clip_val = None

        # Fit (inner-loop optimization)
        super().fit(
            model=model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
        )
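
For reference, the outer-loop update performed in `on_train_epoch_end` above is the standard Reptile interpolation: the saved meta parameters move a fraction `alpha` of the way towards the mean of the task-adapted ("fast") parameters. A minimal sketch on a single tensor, simplified from the state-dict comprehension in the callback (the concrete values here are illustrative only):

    import torch

    # Saved meta parameters and two hypothetical task-adapted copies
    meta_param = torch.zeros(3)
    task_params = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([3.0, 2.0, 1.0])]
    alpha = 0.99  # outer-loop step size, decayed each epoch by alpha_decay

    # Reptile update: theta <- theta + alpha * mean_k(theta_k - theta)
    delta = torch.stack([p - meta_param for p in task_params], dim=0).mean(dim=0)
    meta_param = meta_param + alpha * delta
    print(meta_param)  # tensor([1.9800, 1.9800, 1.9800])

With B = 1 (as in the training script above), the mean is over a single task model, so each meta update simply nudges the meta model towards the parameters reached after the inner-loop epoch on that task.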