Skip to content

Commit

Permalink
Working on perturbation approaches
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielMckenzie committed Jun 21, 2024
1 parent 23484fd commit b3072cc
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/warcraft/PyEPO-warcraft-benchmarks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down
23 changes: 6 additions & 17 deletions src/warcraft/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pyepo.model.grb.grbmodel import optGrbModel
from torchvision.models import resnet18
from src.shortest_path.shortest_path_utils import shortestPathModel_8
import gurobipy as gp
from gurobipy import GRB

class WarcraftShortestPathNet(DYS_opt_net):
def __init__(self, grid_size, A, b, device='mps'):
Expand Down Expand Up @@ -281,15 +283,9 @@ def solve(self):
sol = sol.reshape(-1)
return sol, self._model.objVal


# init model
k = 12
grid = (k, k)
optmodel = shortestPathModel(grid)

# Model for perturbation-based approaches
class WarcraftShortestPathNet(nn.Module):
def __init__(self, grid_size, A, b, device='mps'):
class Pert_WarcraftShortestPathNet(nn.Module):
def __init__(self, grid_size, device='mps'):
super().__init__()
self.device = device
# These layers are like resnet18
Expand All @@ -304,7 +300,7 @@ def __init__(self, grid_size, A, b, device='mps'):
# max pooling
self.maxpool2 = nn.AdaptiveMaxPool2d((grid_size, grid_size))\
# Optimization layer. Can be used within test_time_forward
self.shortest_path_solver = shortestPathModel((12, 12))
self.shortest_path_solver = shortestPathModel((grid_size, grid_size))

def _data_space_forward(self, d):
h = self.conv1(d)
Expand All @@ -317,14 +313,7 @@ def _data_space_forward(self, d):
# reshape for optmodel
out = torch.squeeze(out, 1)
cost_vec = out.reshape(out.shape[0], -1)
if self.training:
batch_size = cost_vec.shape[0]
train_cost_vec = torch.zeros((batch_size, len(self.shortest_path_solver.edges)),device=self.device)
for e, edge in enumerate(self.shortest_path_solver.edges):
train_cost_vec[:,e] = cost_vec[:,edge[1]]
return train_cost_vec
else:
return cost_vec
return cost_vec

# Put it all together
def forward(self, d):
Expand Down
9 changes: 6 additions & 3 deletions src/warcraft/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.warcraft import trainer
from src.warcraft.models import WarcraftShortestPathNet, Cvx_WarcraftShortestPathNet
from src.warcraft.models import WarcraftShortestPathNet, Cvx_WarcraftShortestPathNet, Pert_WarcraftShortestPathNet
import argparse
import os
import dill
Expand Down Expand Up @@ -32,13 +32,16 @@ def main(args):
if args.model_type == "DYS":
net = WarcraftShortestPathNet(args.grid_size, A, b, args.device)
elif args.model_type == "CVX":
net = Cvx_WarcraftShortestPathNet(args.grid_size, A, b, args.device)
net = Cvx_WarcraftShortestPathNet(args.grid_size, A, b, args.device)
elif args.model_type == "PertOpt" or args.model_type == "BB":
net = Pert_WarcraftShortestPathNet(args.grid_size)
else:
print('\n Other models not implemented yet, sorry. \n')
print('\n Please choose an allowed model. \n')
return

net.to(args.device)


# Train!
print('\n---- Model type= ' + args.model_type + ' Grid size = ' + str(args.grid_size) + '---\n')
results = trainer.trainer(net, dataset_train, dataset_test, dataset_val, edges, args.grid_size, args.max_time, args.max_epochs, args.learning_rate, args.model_type, args.weights_dir, args.device)
Expand Down
7 changes: 5 additions & 2 deletions src/warcraft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from src.shortest_path.shortest_path_utils import convert_to_grid_torch, evaluate
from src.utils.accuracy import accuracy
from src.utils.evaluate import evaluate
from src.warcraft.utils import hammingLoss
import tqdm
import numpy as np
import matplotlib.pyplot as plt
Expand All @@ -31,13 +32,15 @@ def trainer(net, train_dataset, test_dataset, val_dataset, edges, grid_size, max
# Initialize loss and evaluation metric
if model_type == "DYS" or model_type == "CVX":
criterion = nn.MSELoss()
elif model_type == "BBOpt" or model_type == "PertOpt":
elif model_type == "BBOpt":
criterion = hammingLoss()
elif model_type == "PertOpt":
criterion = nn.L1Loss()

metric = pyepo.metric.regret

if model_type == "BBOpt":
dbb = pyepo.func.blackboxOpt(net.shortest_path_solver, lambd=5, processes=1)
dbb = pyepo.func.blackboxOpt(net.shortest_path_solver, lambd=10, processes=1)
elif model_type == "PertOpt":
ptb = pyepo.func.perturbedOpt(net.shortest_path_solver, n_samples=3, sigma=1.0, processes=2)
elif model_type == "DYS" or model_type == "CVX":
Expand Down
8 changes: 7 additions & 1 deletion src/warcraft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import os
import time

Expand Down Expand Up @@ -82,4 +83,9 @@ def evaluate(nnet, optmodel, dataloader):
print("Avg Rel Regret: {:.2f}%".format(df["Relative Regret"].mean()*100))
print("Path Accuracy: {:.2f}%".format(df["Accuracy"].mean()*100))
print("Optimality Ratio: {:.2f}%".format(df["Optimal"].mean()*100))
return df
return df

class hammingLoss(nn.Module):
def forward(self, wp, w):
loss = wp * (1.0 - w) + (1.0 - wp) * w
return loss.mean(dim=0).sum()

0 comments on commit b3072cc

Please sign in to comment.