Skip to content

Commit

Permalink
Update 2d-meta_train.py
Browse files Browse the repository at this point in the history
Change some parameters for performance
  • Loading branch information
jieyibi authored Jun 19, 2024
1 parent d11788d commit 60fa8c8
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/2d-meta_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():

# Example callbacks
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints", # save to checkpoints/
dirpath="meta_pomo/checkpoints", # save to checkpoints/
filename="epoch_{epoch:03d}", # save as epoch_XXX.ckpt
save_top_k=1, # save only the best model
save_last=True, # save the last model
Expand All @@ -47,8 +47,8 @@ def main():
# Meta callbacks
meta_callback = ReptileCallback(
num_tasks = 1, # the number of tasks in a mini-batch, i.e. `B` in the original paper
alpha = 0.99, # initial weight of the task model for the outer-loop optimization of reptile
alpha_decay = 0.999, # weight decay of the task model for the outer-loop optimization of reptile
alpha = 0.9, # initial weight of the task model for the outer-loop optimization of reptile
alpha_decay = 1, # weight decay of the task model for the outer-loop optimization of reptile. No decay performs better.
min_size = 20, # minimum of sampled size in meta tasks (only supported in cross-size generalization)
max_size= 150, # maximum of sampled size in meta tasks (only supported in cross-size generalization)
data_type="size_distribution", # choose from ["size", "distribution", "size_distribution"]
Expand All @@ -63,7 +63,7 @@ def main():

# Adjust your trainer to the number of epochs you want to run
trainer = RL4COTrainer(
max_epochs=20000, # (the number of meta_model updates) * (the number of tasks in a mini-batch)
max_epochs=15000, # (the number of meta_model updates) * (the number of tasks in a mini-batch)
callbacks=callbacks,
accelerator="gpu",
devices=[device_id],
Expand Down

0 comments on commit 60fa8c8

Please sign in to comment.