
Commit

Update meta learning framework
jieyibi committed May 28, 2024
1 parent bff31a9 commit 0f4032c
Showing 3 changed files with 42 additions and 33 deletions.
16 changes: 7 additions & 9 deletions examples/2d-meta_train.py
@@ -46,15 +46,13 @@ def main():

# Meta callbacks
meta_callback = ReptileCallback(
meta_params={
'data_type': 'size', # choose from ["size", "distribution", "size_distribution"]
'sch_bar': 0.9, # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs
'B': 1, # the number of tasks in a mini-batch
'alpha': 0.99, # params for the outer-loop optimization of reptile
'alpha_decay': 0.999, # params for the outer-loop optimization of reptile
'min_size': 20, # minimum of sampled size in meta tasks
'max_size': 150, # maximum of sampled size in meta tasks
},
num_tasks=1,        # the number of tasks in a mini-batch
alpha=0.99,         # step size for the outer-loop optimization of Reptile
alpha_decay=0.999,  # decay factor of alpha in the outer-loop optimization of Reptile
min_size=20,        # minimum sampled size in meta tasks
max_size=150,       # maximum sampled size in meta tasks
data_type="size",   # choose from ["size", "distribution", "size_distribution"]
sch_bar=0.9,        # for the task scheduler of the size setting, where sch_epoch = sch_bar * epochs
print_log=True      # whether to print the sampled tasks in each meta iteration
)
callbacks = [meta_callback, checkpoint_callback, rich_model_summary]
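For scripts still built around the old `meta_params` dictionary, the change is a flat set of keyword arguments on `ReptileCallback`. A minimal migration sketch follows; the import paths and the surrounding trainer/model setup are assumptions based on the test file below, not part of this commit:

```python
# Hypothetical migration sketch: old dict-style config mapped onto the new
# keyword-argument API. Import paths are assumed from the repository layout.
from rl4co.utils.meta_trainer import ReptileCallback
from rl4co.utils.trainer import RL4COTrainer

meta_callback = ReptileCallback(
    num_tasks=1,        # was meta_params['B']
    alpha=0.99,         # was meta_params['alpha']
    alpha_decay=0.999,  # was meta_params['alpha_decay']
    min_size=20,        # was meta_params['min_size']
    max_size=150,       # was meta_params['max_size']
    data_type="size",   # was meta_params['data_type']
    sch_bar=0.9,        # was meta_params['sch_bar']
    print_log=True,
)

trainer = RL4COTrainer(max_epochs=100, callbacks=[meta_callback])
trainer.fit(model)  # `model` is any RL4CO model (e.g., POMO), defined elsewhere
```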
53 changes: 32 additions & 21 deletions rl4co/utils/meta_trainer.py
@@ -14,22 +14,29 @@ class ReptileCallback(Callback):

# Meta training framework for addressing the generalization issue
# Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
def __init__(self, meta_params, print_log=True):
def __init__(self,
num_tasks,
alpha,
alpha_decay,
min_size,
max_size,
sch_bar=0.9,
data_type="size",
print_log=True):
super().__init__()
self.meta_params = meta_params
self.print_log = print_log

def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

# Initialize some hyperparameters
self.alpha = self.meta_params["alpha"]
self.alpha_decay = self.meta_params["alpha_decay"]
self.sch_bar = self.meta_params["sch_bar"]
if self.meta_params["data_type"] == "size":
self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)]
self.num_tasks = num_tasks # i.e., B in the paper
self.alpha = alpha
self.alpha_decay = alpha_decay
self.sch_bar = sch_bar
self.print_log = print_log
if data_type == "size":
self.task_set = [(n,) for n in range(min_size, max_size + 1)]
else:
raise NotImplementedError

def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

# Sample a batch of tasks
self._sample_task()
self.selected_tasks[0] = (pl_module.env.generator.num_loc, )
@@ -40,12 +47,12 @@ def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
self._alpha_scheduler()

# Reinitialize the task model with the parameters of the meta model
if trainer.current_epoch % self.meta_params['B'] == 0: # Save the meta model
if trainer.current_epoch % self.num_tasks == 0: # Save the meta model
self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
self.task_models = []
# Print sampled tasks
if self.print_log:
print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.meta_params['B'], trainer.current_epoch, self.selected_tasks))
print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.num_tasks, trainer.current_epoch, self.selected_tasks))
else:
pl_module.load_state_dict(self.meta_model_state_dict)

@@ -57,13 +64,16 @@ def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

# Print
if self.print_log:
print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
if hasattr(pl_module.env.generator, 'capacity'):
print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
else:
print('\n>> Training task: {}'.format(pl_module.env.generator.num_loc))

def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

# Save the task model
self.task_models.append(copy.deepcopy(pl_module.state_dict()))
if (trainer.current_epoch+1) % self.meta_params['B'] == 0:
if (trainer.current_epoch+1) % self.num_tasks == 0:
# Outer-loop optimization (update the meta model with the parameters of the task model)
with torch.no_grad():
state_dict = {params_key: (self.meta_model_state_dict[params_key] +
@@ -73,27 +83,28 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
pl_module.load_state_dict(state_dict)

# Get ready for the next meta-training iteration
if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
if (trainer.current_epoch + 1) % self.num_tasks == 0:
# Sample a batch of tasks
self._sample_task()

# Load new training task (Update the environment)
self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.meta_params['B'])
self._load_task(pl_module, task_idx=(trainer.current_epoch+1) % self.num_tasks)

def _sample_task(self):
# Sample a batch of tasks
w, self.selected_tasks = [1.0] * self.meta_params['B'], []
for b in range(self.meta_params['B']):
w, self.selected_tasks = [1.0] * self.num_tasks, []
for b in range(self.num_tasks):
task_params = random.sample(self.task_set, 1)[0]
self.selected_tasks.append(task_params)
self.w = torch.softmax(torch.Tensor(w), dim=0)

def _load_task(self, pl_module: pl.LightningModule, task_idx=0):
# Load new training task (Update the environment)
task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
pl_module.env.generator.num_loc = task_params[0]
pl_module.env.generator.capacity = task_capacity
if hasattr(pl_module.env.generator, 'capacity'):
task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
pl_module.env.generator.capacity = task_capacity

def _alpha_scheduler(self):
self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
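The comprehension truncated above is the Reptile outer-loop step: every `num_tasks` epochs, the meta model moves a fraction `alpha` of the way toward the average of the just-trained task models. A minimal sketch of that update, assuming the elided part of the comprehension averages the task-model deltas as in Zhou et al. (2023):

```python
import torch

def reptile_outer_update(meta_state_dict, task_state_dicts, alpha):
    """Sketch of the assumed outer-loop step:
    theta_meta <- theta_meta + alpha * mean_i(theta_task_i - theta_meta).
    Assumes all state-dict entries are float tensors."""
    with torch.no_grad():
        return {
            key: meta_state_dict[key]
            + alpha * torch.stack(
                [task[key] - meta_state_dict[key] for task in task_state_dicts]
            ).mean(dim=0)
            for key in meta_state_dict
        }
```

`_alpha_scheduler` then decays this step size geometrically (`alpha *= alpha_decay`, floored at 1e-4), so later meta iterations stay closer to the current meta model. `_load_task` additionally rescales the generator capacity with the sampled size when the environment has one; for example, a sampled `num_loc` of 100 yields `math.ceil(30 + 100 / 5) = 50`.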
6 changes: 3 additions & 3 deletions tests/test_training.py
@@ -122,10 +122,10 @@ def test_pomo_reptile():
policy = AttentionModelPolicy(env_name=env.name, embed_dim=128,
num_encoder_layers=6, num_heads=8,
normalization="instance", use_graph_context=False)
model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10)
model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10, test_data_size=10)
meta_callback = ReptileCallback(
meta_params={'data_type': 'size', 'sch_bar': 0.9, 'B': 2, 'alpha': 0.99,
'alpha_decay': 0.999, 'min_size': 20, 'max_size': 50}
data_type="size", sch_bar=0.9, num_tasks=2, alpha = 0.99,
alpha_decay = 0.999, min_size = 20, max_size =50
)
trainer = RL4COTrainer(max_epochs=2, callbacks=[meta_callback], devices=1, accelerator=accelerator, limit_train_batches=3)
trainer.fit(model)
