diff --git a/guiding_BigGAN.py b/guiding_BigGAN.py
index 808c5095..4c743ea1 100644
--- a/guiding_BigGAN.py
+++ b/guiding_BigGAN.py
@@ -2,8 +2,6 @@
import math
import numpy as np
from tqdm import tqdm, trange
-
-
import torch
import torch.nn as nn
from torch.nn import init
@@ -11,13 +9,54 @@
import torch.nn.functional as F
from torch.nn import Parameter as P
import torchvision
+from vgg_face import Vgg_face_dag
+from torch.utils.data import Dataset, DataLoader
# Import my stuff
import inception_utils
import utils
import losses
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
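+# ConditionalBigGAN chains a pretrained BigGAN generator G with a face
+# classifier. An InputGenerator MLP maps a one-hot class label to the mean
+# and standard deviation of a per-class latent distribution, and z is sampled
+# with the reparameterization trick so the classifier loss can backpropagate
+# into the MLP.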
+class ConditionalBigGAN(nn.Module):
+    def __init__(self, G, classifier, input_dim):
+        super().__init__()
+        self.G = G
+        self.classifier = classifier
+        self.input_dim = input_dim
+        self.output_dim = G.dim_z
+        self.input_generator = InputGenerator(input_dim, self.output_dim)
+    def forward(self, inputs):
+        # flatten labels to shape (batch,) so they work both for one-hot
+        # encoding and for the generator's class-embedding lookup
+        inputs = inputs.view(-1)
+        onehot_inputs = nn.functional.one_hot(inputs, num_classes=self.input_dim).float()
+        mu_sigma = self.input_generator(onehot_inputs)
+        mu = mu_sigma[:, :self.output_dim]
+        sigma = mu_sigma[:, self.output_dim:]
+        # softplus keeps sigma positive
+        sigma = nn.functional.softplus(sigma)
+        # sample z = mu + sigma * eps, with eps drawn on the same device as mu
+        eps = torch.randn_like(mu)
+        z = mu + sigma * eps
+        # generate images conditioned on the class embedding
+        images = self.G(z, self.G.shared(inputs))
+        outputs = self.classifier(images)
+        return outputs
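+# MLP that maps a one-hot class label to a vector of size 2 * dim_z holding
+# the concatenated (mu, sigma) parameters of the latent distribution.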
+class InputGenerator(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.fc = nn.Sequential(
+            nn.Linear(input_dim, 256),
+            nn.ReLU(),
+            nn.Linear(256, 512),
+            nn.ReLU(),
+            # twice the latent size: one half for mu, one half for sigma
+            nn.Linear(512, output_dim * 2)
+        )
+    def forward(self, inputs):
+        return self.fc(inputs)
+
def load_BigGAN_generator(config):
# Prepare state dict, which holds things like epoch # and itr #
state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
@@ -61,6 +100,9 @@ def load_BigGAN_generator(config):
strict=False, load_optim=False)
return G
def main():
+    batch_size = 64
+    learning_rate = 1e-3
+    epochs = 1000
#parse command line and run
parser = utils.prepare_parser()
parser = utils.add_sample_parser(parser)
@@ -78,7 +120,87 @@ def main():
'km.jpg' ,
nrow=int(G_batch_size**0.5),
normalize=True)
-
+    model = get_race_classifier()
+    cGAN = ConditionalBigGAN(G, model, 5).to(device)
+    # sanity check: run a random batch of labels through the full pipeline
+    inputs = torch.randint(0, 5, size=(32,), device=device)
+    outputs = cGAN(inputs)
+    print(outputs.shape)
+    train_dataset = DummyDataset(20000)
+    validation_dataset = DummyDataset(5000)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    val_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
+    # optimize only the input generator; G and the classifier stay fixed
+    optimizer = torch.optim.SGD(cGAN.input_generator.parameters(), lr=learning_rate, momentum=0.9)
+    criterion = torch.nn.CrossEntropyLoss()
+    # training loop
+    for epoch in range(epochs):
+        train_loss, train_acc = train_epoch(cGAN, train_loader, optimizer, criterion)
+        val_loss, val_acc = evaluate(cGAN, val_loader, criterion)
+        # log
+        print("Epoch:{}/{} Train_loss:{:.4f} Train_acc:{:.2f}%".format(epoch + 1, epochs, train_loss, train_acc * 100))
+        print("val_loss:{:.4f} val_acc:{:.2f}%".format(val_loss, val_acc * 100))
+
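+# Build the race classifier: load pretrained VGG-Face weights, then replace
+# the fc6-fc8 layers with a new 5-way classification head.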
+def get_race_classifier():
+    model = Vgg_face_dag()
+    state_dict = torch.load("vgg_face_dag.pth")
+    model.load_state_dict(state_dict)
+    # optionally freeze the pretrained layers
+    # for param in model.parameters():
+    #     param.requires_grad = False
+    model.fc6 = torch.nn.Linear(in_features=512, out_features=512, bias=True)
+    model.fc7 = torch.nn.Linear(in_features=512, out_features=1024, bias=True)
+    model.fc8 = torch.nn.Linear(in_features=1024, out_features=5, bias=True)
+    return model
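+# Synthetic dataset of random class labels. Each sample returns the label as
+# both the conditioning input and the training target, so the model is pushed
+# to generate images that the classifier assigns to the requested class.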
+class DummyDataset(Dataset):
+    def __init__(self, size):
+        self.size = size
+    def __len__(self):
+        return self.size
+    def __getitem__(self, index):
+        label = torch.randint(0, 5, size=(1,))
+        return label, label
+
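+# Run one training epoch, updating only the parameters held by the optimizer
+# (here: the InputGenerator).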
+def train_epoch(model, train_loader, optimizer, criterion):
+    model.train()
+    total = 0
+    losses = 0
+    corrects = 0
+    for inputs, labels in tqdm(train_loader, total=len(train_loader)):
+        inputs = inputs.to(device)
+        labels = labels.to(device)
+        outputs = model(inputs)
+        loss = criterion(outputs, labels.squeeze(1))
+        # accumulate the batch loss weighted by batch size
+        losses += loss.item() * labels.size(0)
+        total += labels.size(0)
+        predictions = outputs.argmax(dim=-1)
+        # number of correctly predicted samples
+        corrects += (predictions == labels.squeeze(1)).float().sum().item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # return average loss and accuracy
+    return losses / total, corrects / total
+
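+# Validation pass: same metrics as train_epoch, but without gradient updates.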
+def evaluate(model, val_loader, criterion):
+    model.eval()
+    total = 0
+    losses = 0
+    corrects = 0
+    # no_grad disables gradient tracking to save memory during validation
+    with torch.no_grad():
+        for inputs, labels in tqdm(val_loader, total=len(val_loader)):
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            outputs = model(inputs)
+            loss = criterion(outputs, labels.squeeze(1))
+            losses += loss.item() * labels.size(0)
+            total += labels.size(0)
+            predictions = outputs.argmax(dim=-1)
+            # number of correctly predicted samples
+            corrects += (predictions == labels.squeeze(1)).float().sum().item()
+    # return average loss and accuracy
+    return losses / total, corrects / total
if __name__ == '__main__':
main()