Merge pull request #183 from jieyibi/main
Add meta learning framework
Showing 5 changed files with 531 additions and 1 deletion.
@@ -0,0 +1,80 @@
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback


def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={'num_loc': 50})

    # Policy: neural network, in this case with an encoder-decoder architecture
    # Note: these settings follow POMO as used in the original paper
    policy = AttentionModelPolicy(env_name=env.name,
                                  embed_dim=128,
                                  num_encoder_layers=6,
                                  num_heads=8,
                                  normalization="instance",
                                  use_graph_context=False
                                  )

    # RL Model (POMO)
    model = POMO(env,
                 policy,
                 batch_size=64,  # meta_batch_size
                 train_data_size=64 * 50,  # equals (meta_batch_size) * (gradient descent steps in the inner-loop optimization of the meta-learning method)
                 val_data_size=0,
                 optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
                 )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="meta_pomo/checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback

    # Meta callbacks
    meta_callback = ReptileCallback(
        num_tasks=1,  # number of tasks in a mini-batch, i.e. `B` in the original paper
        alpha=0.9,  # initial weight of the task model in the outer-loop optimization of Reptile
        alpha_decay=1,  # weight decay of the task model in the outer-loop optimization of Reptile; no decay performs better
        min_size=20,  # minimum sampled size of meta tasks (only supported for cross-size generalization)
        max_size=150,  # maximum sampled size of meta tasks (only supported for cross-size generalization)
        data_type="size_distribution",  # choose from ["size", "distribution", "size_distribution"]
        sch_bar=0.9,  # task scheduler for the size setting, where lr_decay_epoch = sch_bar * epochs, i.e. after this epoch the learning rate decays by a factor of 0.1
        print_log=True  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    logger = WandbLogger(project="rl4co", name=f"{env.name}_pomo_reptile")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COTrainer(
        max_epochs=15000,  # (number of meta_model updates) * (number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50  # gradient descent steps in the inner-loop optimization of the meta-learning method
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()
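The comments in the script above couple several hyperparameters: the POMO batch_size acts as the meta batch size, limit_train_batches is the number of inner-loop gradient descent steps, and max_epochs counts meta-model updates times tasks per mini-batch. The short sketch below (illustrative only, not part of this commit; the variable names are made up) spells out those relationships:

# Hypothetical helper mirroring the comments in the script above; names are illustrative.
meta_batch_size = 64       # POMO batch_size: instances per inner-loop gradient descent step
inner_loop_steps = 50      # RL4COTrainer limit_train_batches
tasks_per_minibatch = 1    # ReptileCallback num_tasks, `B` in the original paper

train_data_size = meta_batch_size * inner_loop_steps  # 64 * 50 = 3200, matching the script
meta_model_updates = 15000 // tasks_per_minibatch     # from max_epochs = updates * tasks per mini-batch
print(train_data_size, meta_model_updates)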
@@ -0,0 +1,254 @@
import random

import torch


class Cluster():
    """
    Multiple Gaussian-distributed clusters, as in the Solomon benchmark dataset.
    Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

    Args:
        n_cluster: Number of Gaussian-distributed clusters
    """

    def __init__(self, n_cluster: int = 3):
        super().__init__()
        self.lower, self.upper = 0.2, 0.8
        self.std = 0.07
        self.n_cluster = n_cluster

    def sample(self, size):
        batch_size, num_loc, _ = size

        # Generate the centers of the clusters
        center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster * 2)

        # Pre-define the coordinates
        coords = torch.zeros(batch_size, num_loc, 2)

        # Calculate the size of each cluster
        cluster_sizes = [num_loc // self.n_cluster] * self.n_cluster
        for i in range(num_loc % self.n_cluster):
            cluster_sizes[i] += 1

        # Generate the coordinates
        current_index = 0
        for i in range(self.n_cluster):
            means = center[:, i * 2:(i + 1) * 2]
            stds = torch.full((batch_size, 2), self.std)
            points = torch.normal(means.unsqueeze(1).expand(-1, cluster_sizes[i], -1),
                                  stds.unsqueeze(1).expand(-1, cluster_sizes[i], -1))
            coords[:, current_index:current_index + cluster_sizes[i], :] = points
            current_index += cluster_sizes[i]

        # Confine the coordinates to the range [0, 1]
        coords.clamp_(0, 1)

        return coords


class Mixed():
    """
    50% of the nodes are sampled from a uniform distribution and 50% from a Gaussian distribution, as in the Solomon benchmark dataset.
    Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

    Args:
        n_cluster_mix: Number of Gaussian-distributed clusters
    """

    def __init__(self, n_cluster_mix=1):
        super().__init__()
        self.lower, self.upper = 0.2, 0.8
        self.std = 0.07
        self.n_cluster_mix = n_cluster_mix

    def sample(self, size):
        batch_size, num_loc, _ = size

        # Generate the centers of the clusters
        center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster_mix * 2)

        # Pre-define the coordinates sampled from a uniform distribution
        coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1)

        # Sample mutated indices (default setting: 50% mutation)
        mutate_idx = torch.stack([torch.randperm(num_loc)[:num_loc // 2] for _ in range(batch_size)])

        # Generate the coordinates
        segment_size = num_loc // (2 * self.n_cluster_mix)
        remaining_indices = num_loc // 2 - segment_size * (self.n_cluster_mix - 1)
        sizes = [segment_size] * (self.n_cluster_mix - 1) + [remaining_indices]
        for i in range(self.n_cluster_mix):
            indices = mutate_idx[:, sum(sizes[:i]):sum(sizes[:i + 1])]
            means_x = center[:, 2 * i].unsqueeze(1).expand(-1, sizes[i])
            means_y = center[:, 2 * i + 1].unsqueeze(1).expand(-1, sizes[i])
            coords.scatter_(1, indices.unsqueeze(-1).expand(-1, -1, 2),
                            torch.stack([
                                torch.normal(means_x.expand(-1, sizes[i]), self.std),
                                torch.normal(means_y.expand(-1, sizes[i]), self.std)
                            ], dim=2))

        # Confine the coordinates to the range [0, 1]
        coords.clamp_(0, 1)

        return coords


class Gaussian_Mixture():
    """
    Following Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Args:
        num_modes: Number of clusters/modes in the Gaussian mixture
        cdist: Scale of the uniform distribution used to generate the cluster centers
    """

    def __init__(self, num_modes: int = 0, cdist: int = 0):
        super().__init__()
        self.num_modes = num_modes
        self.cdist = cdist

    def sample(self, size):
        batch_size, num_loc, _ = size

        if self.num_modes == 0:  # (0, 0) - uniform
            return torch.rand((batch_size, num_loc, 2))
        elif self.num_modes == 1 and self.cdist == 1:  # (1, 1) - gaussian
            return self.generate_gaussian(batch_size, num_loc)
        else:
            res = [self.generate_gaussian_mixture(num_loc) for _ in range(batch_size)]
            return torch.stack(res)

    def generate_gaussian_mixture(self, num_loc):
        """Following the setting in Zhang et al. 2022 (https://arxiv.org/abs/2204.03236)"""

        # Randomly decide how many points each mode gets
        nums = torch.multinomial(input=torch.ones(self.num_modes) / self.num_modes, num_samples=num_loc, replacement=True)

        # Prepare to collect points
        coords = torch.empty((0, 2))

        # Generate points for each mode
        for i in range(self.num_modes):
            num = (nums == i).sum()  # Number of points in this mode
            if num > 0:
                center = torch.rand((1, 2)) * self.cdist
                cov = torch.eye(2)  # Covariance matrix
                nxy = torch.distributions.MultivariateNormal(center.squeeze(), covariance_matrix=cov).sample((num,))
                coords = torch.cat((coords, nxy), dim=0)

        return self._global_min_max_scaling(coords)

    def generate_gaussian(self, batch_size, num_loc):
        """Following the setting in Xin et al. 2022 (https://openreview.net/pdf?id=nJuzV-izmPJ)"""

        # Mean and random covariances
        mean = torch.full((batch_size, num_loc, 2), 0.5)
        covs = torch.rand(batch_size)  # Random covariances between 0 and 1

        # Generate the coordinates
        coords = torch.zeros((batch_size, num_loc, 2))
        for i in range(batch_size):
            # Construct the covariance matrix for each sample
            cov_matrix = torch.tensor([[1.0, covs[i]], [covs[i], 1.0]])
            m = torch.distributions.MultivariateNormal(mean[i], covariance_matrix=cov_matrix)
            coords[i] = m.sample()

        # Shuffle the samples along the batch dimension
        indices = torch.randperm(coords.size(0))
        coords = coords[indices]

        return self._batch_normalize_and_center(coords)

    def _global_min_max_scaling(self, coords):
        # Scale the points to [0, 1] using min-max scaling
        coords_min = coords.min(0, keepdim=True).values
        coords_max = coords.max(0, keepdim=True).values
        coords = (coords - coords_min) / (coords_max - coords_min)

        return coords

    def _batch_normalize_and_center(self, coords):
        # Step 1: Compute the min and max of each sample in the batch
        coords_min = coords.min(dim=1, keepdim=True).values
        coords_max = coords.max(dim=1, keepdim=True).values

        # Step 2: Normalize the coordinates to the range [0, 1]
        coords = coords - coords_min  # Broadcasting subtracts the per-sample min from each coordinate
        range_max = (coords_max - coords_min).max(dim=-1, keepdim=True).values  # Maximum range over both coordinates
        coords = coords / range_max  # Divide by the maximum range to normalize

        # Step 3: Center each sample in the middle of the [0, 1] range
        coords = coords + (1 - coords.max(dim=1, keepdim=True).values) / 2

        return coords


class Mix_Distribution():
    """
    Batch-level mixture of three exemplar distributions, i.e. Uniform, Cluster, and Mixed.
    Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

    Args:
        n_cluster: Number of Gaussian-distributed clusters in the Cluster distribution
        n_cluster_mix: Number of Gaussian-distributed clusters in the Mixed distribution
    """

    def __init__(self, n_cluster=3, n_cluster_mix=1):
        super().__init__()
        self.lower, self.upper = 0.2, 0.8
        self.std = 0.07
        self.Mixed = Mixed(n_cluster_mix=n_cluster_mix)
        self.Cluster = Cluster(n_cluster=n_cluster)

    def sample(self, size):
        batch_size, num_loc, _ = size

        # Pre-define the coordinates sampled from a uniform distribution
        coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1)

        # Randomly sample a probability that picks the distribution of each sample
        p = torch.rand(batch_size)

        # Mixed
        mask = p <= 0.33
        n_mixed = mask.sum().item()
        if n_mixed > 0:
            coords[mask] = self.Mixed.sample((n_mixed, num_loc, 2))

        # Cluster
        mask = (p > 0.33) & (p <= 0.66)
        n_cluster = mask.sum().item()
        if n_cluster > 0:
            coords[mask] = self.Cluster.sample((n_cluster, num_loc, 2))

        # The remaining samples stay uniformly distributed
        return coords


class Mix_Multi_Distributions():
    """
    Batch-level mixture of 11 Gaussian-like distributions.
    Following the setting in Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    """

    def __init__(self):
        super().__init__()
        self.dist_set = [(0, 0), (1, 1)] + [(m, c) for m in [3, 5, 7] for c in [10, 30, 50]]

    def sample(self, size):
        batch_size, num_loc, _ = size
        coords = torch.zeros(batch_size, num_loc, 2)

        # Pre-select a distribution for each sample in the batch
        dists = [random.choice(self.dist_set) for _ in range(batch_size)]
        unique_dists = list(set(dists))  # Unique distributions, to minimize re-instantiation

        # Instantiate Gaussian_Mixture only once per unique distribution
        gm_instances = {dist: Gaussian_Mixture(*dist) for dist in unique_dists}

        # Sample each instance from its pre-selected distribution
        for i, dist in enumerate(dists):
            coords[i] = gm_instances[dist].sample((1, num_loc, 2)).squeeze(0)

        return coords
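All of the samplers above expose the same sample(size) interface, where size is a (batch_size, num_loc, 2) tuple and the return value is a tensor of 2D node coordinates in [0, 1]. A minimal usage sketch (illustrative only, not part of this commit; it assumes the classes above are available in the current module):

# Minimal usage sketch (illustrative, not part of this commit).
if __name__ == "__main__":
    # Single Gaussian-cluster distribution
    cluster = Cluster(n_cluster=3)
    coords = cluster.sample((16, 50, 2))  # (batch_size, num_loc, 2)
    assert coords.shape == (16, 50, 2)
    assert coords.min() >= 0 and coords.max() <= 1  # clamped to the unit square

    # Batch-level mixture of Uniform / Cluster / Mixed samples
    mixed_coords = Mix_Distribution(n_cluster=3, n_cluster_mix=1).sample((16, 50, 2))
    print(coords.shape, mixed_coords.shape)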