diff --git a/guiding_BigGAN.py b/guiding_BigGAN.py index 808c5095..4c743ea1 100644 --- a/guiding_BigGAN.py +++ b/guiding_BigGAN.py @@ -2,8 +2,6 @@ import math import numpy as np from tqdm import tqdm, trange - - import torch import torch.nn as nn from torch.nn import init @@ -11,13 +9,54 @@ import torch.nn.functional as F from torch.nn import Parameter as P import torchvision +from vgg_face import Vgg_face_dag +from torch import nn +from torch.utils.data import Dataset,DataLoader # Import my stuff import inception_utils import utils import losses -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +class ConditionalBigGAN(nn.Module): + def __init__(self,G,classifier,input_dim): + super().__init__() + self.G = G + self.classifier = classifier + self.output_dim = G.dim_z + self.input_generator = InputGenerator(input_dim,self.output_dim) + def forward(self,inputs): + onehot_inputs = nn.functional.one_hot(inputs).float().squeeze(1) + mu_sigma = self.input_generator(onehot_inputs) + mu = mu_sigma[:,:self.output_dim] + sigma = mu_sigma[:,self.output_dim:] + # to make sigma postive + sigma = nn.functional.softplus(sigma) + # generat random input z (epslon) + eps = torch.randn(*mu.shape) + z = mu+sigma * eps + # generat images + images = self.G(z,self.G.shared(inputs)) + outputs = self.classifier(images) + return outputs +class InputGenerator(nn.Module): + def __init__(self,input_dim,output_dim): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.fc = nn.Sequential( + nn.Linear(input_dim,256) , + nn.ReLU(), + nn.Linear(256,512), + nn.ReLU(), + # input Genretor for mu and std (*2) + nn.Linear(512,output_dim *2 ) + ) + # forward pass + def forward(self,inputs): + return self.fc(inputs) + def load_BigGAN_generator(config): # Prepare state dict, which holds things like epoch # and itr # state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, @@ -61,6 +100,9 @@ def load_BigGAN_generator(config): strict=False, load_optim=False) return G def main(): + batch_size = 64 + learning_rate = 1e-3 + epochs = 1000 #parse command line and run parser = utils.prepare_parser() parser = utils.add_sample_parser(parser) @@ -78,7 +120,87 @@ def main(): 'km.jpg' , nrow=int(G_batch_size**0.5), normalize=True) - + model = get_race_classifier() + cGAN = ConditionalBigGAN(G,model,5) + # generate input + inputs = torch.randint(0,5,size=(32,)) + outputs = cGAN(inputs) + print(outputs.shape) + train_dataset = DummyDataset(20000) + validation_dataset = DummyDataset(5000) + train_loader = DataLoader(train_dataset,batch_size = batch_size,shuffle = True,drop_last = True) + val_loader = DataLoader(validation_dataset,batch_size = batch_size,shuffle = False,drop_last = False) + optimizer =torch.optim.SGD(cGAN.input_generator.parameters(),lr = learning_rate,momentum=0.9) + criterion = torch.nn.CrossEntropyLoss() + # traning loop + for epoch in range(epochs): + train_loss,train_acc = train_epoch(cGAN,train_loader,optimizer,criterion) + val_loss,val_acc = evaluate(cGAN,val_loader,criterion) + # log + print("Epoch:{}/{} Train_loss:{:.4f} Train_acc:{:.2f}%".format(epoch+1,epochs,train_loss,train_acc*100)) + print("val_loss:{:.4f} val_acc:{:.2f}%".format(val_loss,val_acc*100)) + +def get_race_classifier(): + model = Vgg_face_dag() + state_dict = torch.load("vgg_face_dag.pth") + model.load_state_dict(state_dict) + # freez the other layer + # for param in model.parameters(): + # param.requires_grad = False + model.fc6 = torch.nn.Linear(in_features=512, out_features=512, bias=True) + model.fc7 = torch.nn.Linear(in_features=512, out_features=1024, bias=True) + model.fc8 = torch.nn.Linear(in_features=1024, out_features=5, bias=True) + return model +class DummyDataset: + def __init__(self,size): + self.size = size + def __len__(self): + return self.size + def __getitem__(self,index): + output = torch.randint(0,5,size=(1,)) + return output,output + +def train_epoch(model,train_loader,optimizer,criterion): + model.train() + total = 0 + losses = 0 + corrects = 0 + for inputs,labels in tqdm(train_loader,total=len(train_loader)): + inputs = inputs.to(device) + labels = labels.to(device) + outputs = model(inputs) + loss = criterion(outputs,labels.squeeze(1)) + # batch total loss + losses+=loss.item() * labels.size(0) + total+=labels.size(0) + predictions = outputs.argmax(dim = -1) + # number of correctly predicted class + corrects+=(predictions==labels.squeeze(1)).float().sum() + optimizer.zero_grad() + loss.backward() + optimizer.step() + # return avg loss and acc + return losses/total , corrects/total + +def evaluate(model,val_loader,criterion): + model.eval() + total = 0 + losses = 0 + corrects = 0 + # to skip backpropagation step + with torch.no_grad(): + for inputs,labels in tqdm(val_loader,total=len(val_loader)): + inputs = inputs.to(device) + labels = labels.to(device) + outputs = model(inputs) + loss = criterion(outputs,labels.squeeze(1)) + losses+=loss.item() * labels.size(0) + total+=labels.size(0) + predictions = outputs.argmax(dim = -1) + # number of correctly predicted class + corrects+=(predictions==labels.squeeze(1)).float().sum() + # return avg loss and acc + return losses/total , corrects/total if __name__ == '__main__': main()