Skip to content

Commit

Permalink
load guidingBiggan script
Browse files Browse the repository at this point in the history
  • Loading branch information
kidist-amde committed Jan 23, 2022
1 parent 6b29378 commit 9e6d693
Showing 1 changed file with 126 additions and 4 deletions.
130 changes: 126 additions & 4 deletions guiding_BigGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,61 @@
import math
import numpy as np
from tqdm import tqdm, trange


import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
import torchvision
from vgg_face import Vgg_face_dag
from torch import nn
from torch.utils.data import Dataset,DataLoader

# Import my stuff
import inception_utils
import utils
import losses
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ConditionalBigGAN(nn.Module):
def __init__(self,G,classifier,input_dim):
super().__init__()
self.G = G
self.classifier = classifier
self.output_dim = G.dim_z
self.input_generator = InputGenerator(input_dim,self.output_dim)
def forward(self,inputs):
onehot_inputs = nn.functional.one_hot(inputs).float().squeeze(1)
mu_sigma = self.input_generator(onehot_inputs)
mu = mu_sigma[:,:self.output_dim]
sigma = mu_sigma[:,self.output_dim:]
# to make sigma postive
sigma = nn.functional.softplus(sigma)
# generat random input z (epslon)
eps = torch.randn(*mu.shape)
z = mu+sigma * eps
# generat images
images = self.G(z,self.G.shared(inputs))
outputs = self.classifier(images)
return outputs
class InputGenerator(nn.Module):
def __init__(self,input_dim,output_dim):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.fc = nn.Sequential(
nn.Linear(input_dim,256) ,
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
# input Genretor for mu and std (*2)
nn.Linear(512,output_dim *2 )
)
# forward pass
def forward(self,inputs):
return self.fc(inputs)

def load_BigGAN_generator(config):
# Prepare state dict, which holds things like epoch # and itr #
state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
Expand Down Expand Up @@ -61,6 +100,9 @@ def load_BigGAN_generator(config):
strict=False, load_optim=False)
return G
def main():
batch_size = 64
learning_rate = 1e-3
epochs = 1000
#parse command line and run
parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
Expand All @@ -78,7 +120,87 @@ def main():
'km.jpg' ,
nrow=int(G_batch_size**0.5),
normalize=True)

model = get_race_classifier()
cGAN = ConditionalBigGAN(G,model,5)
# generate input
inputs = torch.randint(0,5,size=(32,))
outputs = cGAN(inputs)
print(outputs.shape)
train_dataset = DummyDataset(20000)
validation_dataset = DummyDataset(5000)
train_loader = DataLoader(train_dataset,batch_size = batch_size,shuffle = True,drop_last = True)
val_loader = DataLoader(validation_dataset,batch_size = batch_size,shuffle = False,drop_last = False)
optimizer =torch.optim.SGD(cGAN.input_generator.parameters(),lr = learning_rate,momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
# traning loop
for epoch in range(epochs):
train_loss,train_acc = train_epoch(cGAN,train_loader,optimizer,criterion)
val_loss,val_acc = evaluate(cGAN,val_loader,criterion)
# log
print("Epoch:{}/{} Train_loss:{:.4f} Train_acc:{:.2f}%".format(epoch+1,epochs,train_loss,train_acc*100))
print("val_loss:{:.4f} val_acc:{:.2f}%".format(val_loss,val_acc*100))

def get_race_classifier():
model = Vgg_face_dag()
state_dict = torch.load("vgg_face_dag.pth")
model.load_state_dict(state_dict)
# freez the other layer
# for param in model.parameters():
# param.requires_grad = False
model.fc6 = torch.nn.Linear(in_features=512, out_features=512, bias=True)
model.fc7 = torch.nn.Linear(in_features=512, out_features=1024, bias=True)
model.fc8 = torch.nn.Linear(in_features=1024, out_features=5, bias=True)
return model
class DummyDataset:
def __init__(self,size):
self.size = size
def __len__(self):
return self.size
def __getitem__(self,index):
output = torch.randint(0,5,size=(1,))
return output,output

def train_epoch(model,train_loader,optimizer,criterion):
model.train()
total = 0
losses = 0
corrects = 0
for inputs,labels in tqdm(train_loader,total=len(train_loader)):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(outputs,labels.squeeze(1))
# batch total loss
losses+=loss.item() * labels.size(0)
total+=labels.size(0)
predictions = outputs.argmax(dim = -1)
# number of correctly predicted class
corrects+=(predictions==labels.squeeze(1)).float().sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
# return avg loss and acc
return losses/total , corrects/total

def evaluate(model,val_loader,criterion):
model.eval()
total = 0
losses = 0
corrects = 0
# to skip backpropagation step
with torch.no_grad():
for inputs,labels in tqdm(val_loader,total=len(val_loader)):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(outputs,labels.squeeze(1))
losses+=loss.item() * labels.size(0)
total+=labels.size(0)
predictions = outputs.argmax(dim = -1)
# number of correctly predicted class
corrects+=(predictions==labels.squeeze(1)).float().sum()
# return avg loss and acc
return losses/total , corrects/total

if __name__ == '__main__':
main()
Expand Down

0 comments on commit 9e6d693

Please sign in to comment.