Skip to content

Commit

Permalink
init repository
Browse files Browse the repository at this point in the history
  • Loading branch information
raven38 committed Mar 15, 2021
0 parents commit 80e15f3
Show file tree
Hide file tree
Showing 11 changed files with 1,249 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "sefa"]
path = sefa
url = https://github.com/genforce/sefa.git
44 changes: 44 additions & 0 deletions align_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import sys
import bz2
from tensorflow.keras.utils import get_file
from face_alignment import image_align
from landmarks_detector import LandmarksDetector

LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'


def unpack_bz2(src_path):
data = bz2.BZ2File(src_path).read()
dst_path = src_path[:-4]
with open(dst_path, 'wb') as fp:
fp.write(data)
return dst_path


def align_image(input_path, output_path):
landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
LANDMARKS_MODEL_URL, cache_subdir='temp'))
landmarks_detector = LandmarksDetector(landmarks_model_path)
for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(input_path), start=1):
image_align(input_path, output_path, face_landmarks)

if __name__ == "__main__":
"""
Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
python align_images.py /raw_images /aligned_images
"""

landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
LANDMARKS_MODEL_URL, cache_subdir='temp'))
RAW_IMAGES_DIR = sys.argv[1]
ALIGNED_IMAGES_DIR = sys.argv[2]

landmarks_detector = LandmarksDetector(landmarks_model_path)
for img_name in os.listdir(RAW_IMAGES_DIR):
raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)

image_align(raw_img_path, aligned_face_path, face_landmarks)
148 changes: 148 additions & 0 deletions encode_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import matplotlib.pyplot as plt
from read_image import image_reader
import argparse
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from torchvision.utils import save_image
from perceptual_model import VGG16_for_Perceptual
import torch.optim as optim

from sefa.models import parse_gan_type
from utils import load_generator, to_tensor, parse_gan_type, postprocess

device = "cuda:0" if torch.cuda.is_available() else "cpu"


def parse_resolution(model_name):
return int(''.join(filter(str.isdigit, model_name)))


def forward(model, gan_type, code):
if gan_type == 'pggan':
image = model(code)['image']
elif gan_type in ['stylegan', 'stylegan2']:
image = model.synthesis(code)['image']
return image


def optimize_style(source_image, model, model_name, gan_type, dlatent, iteration, pb):
resolution = parse_resolution(model_name)

img = image_reader(source_image, resize=resolution) # (1,3,1024,1024) -1~1
img = img.to(device)

MSE_Loss = nn.MSELoss(reduction="mean")

img_p = img.clone() # Perceptual loss 用画像
upsample2d = torch.nn.Upsample(
scale_factor=256 / resolution, mode="bilinear"
) # VGG入力のため(256,256)にリサイズ
img_p = upsample2d(img_p)

perceptual_net = VGG16_for_Perceptual(n_layers=[2, 4, 14, 21]).to(device)
w = to_tensor(dlatent).requires_grad_()
optimizer = optim.Adam({w}, lr=0.01, betas=(0.9, 0.999), eps=1e-8)

for i in range(iteration):
pb.progress(i / iteration)
optimizer.zero_grad()
synth_img = forward(model, gan_type, w)
synth_img = (synth_img + 1.0) / 2.0
mse_loss, perceptual_loss = caluclate_loss(
synth_img, img, perceptual_net, img_p, MSE_Loss, upsample2d
)
loss = mse_loss + perceptual_loss
loss.backward()
optimizer.step()

return w.detach().cpu().numpy()


def main():
parser = argparse.ArgumentParser(
description="Find latent representation of reference images using perceptual loss"
)

parser.add_argument("--src_im", default="sample.png")
parser.add_argument("--src_dir", default="source_image/")

iteration = 1000
args = parser.parse_args()

model_name = 'stylegan_ffhq1024'
model = load_generator(model_name)
resolution = parse_resolution(model_name)
gan_type = parse_gan_type(model)

name = args.src_im.split(".")[0]
img = image_reader(args.src_dir + args.src_im, resize=resolution) # (1,3,1024,1024) -1~1
img = img.to(device)

MSE_Loss = nn.MSELoss(reduction="mean")

img_p = img.clone() # Perceptual loss 用画像
upsample2d = torch.nn.Upsample(
scale_factor=256 / resolution, mode="bilinear"
) # VGG入力のため(256,256)にリサイズ
img_p = upsample2d(img_p)

perceptual_net = VGG16_for_Perceptual(n_layers=[2, 4, 14, 21]).to(device)
# dlatent = torch.randn(1, model.z_space_dim, requires_grad=True, device=device)
w = to_tensor(sample(model, gan_type)).requires_grad_()
optimizer = optim.Adam({w}, lr=0.01, betas=(0.9, 0.999), eps=1e-8)
# optimizer = optim.SGD({dlatent}, lr=1.) #, momentum=0.9, nesterov=True)

print("Start")
loss_list = []
for i in range(iteration):
optimizer.zero_grad()

synth_img = forward(model, gan_type, w)
synth_img = (synth_img + 1.0) / 2.0
mse_loss, perceptual_loss = caluclate_loss(
synth_img, img, perceptual_net, img_p, MSE_Loss, upsample2d
)
loss = mse_loss + perceptual_loss
loss.backward()

optimizer.step()

loss_np = loss.detach().cpu().numpy()
loss_p = perceptual_loss.detach().cpu().numpy()
loss_m = mse_loss.detach().cpu().numpy()

loss_list.append(loss_np)
if i % 10 == 0:
print(
"iter{}: loss -- {}, mse_loss --{}, percep_loss --{}".format(
i, loss_np, loss_m, loss_p
)
)
save_image(synth_img.clamp(0, 1), "save_image/encode1/{}.png".format(i))
# np.save("loss_list.npy",loss_list)
np.save("latent_W/{}.npy".format(name), w.detach().cpu().numpy())


def caluclate_loss(synth_img, img, perceptual_net, img_p, MSE_Loss, upsample2d):
# calculate MSE Loss
mse_loss = MSE_Loss(synth_img, img) # (lamda_mse/N)*||G(w)-I||^2

# calculate Perceptual Loss
real_0, real_1, real_2, real_3 = perceptual_net(img_p)
synth_p = upsample2d(synth_img) # (1,3,256,256)
synth_0, synth_1, synth_2, synth_3 = perceptual_net(synth_p)

perceptual_loss = 0
perceptual_loss += MSE_Loss(synth_0, real_0)
perceptual_loss += MSE_Loss(synth_1, real_1)
perceptual_loss += MSE_Loss(synth_2, real_2)
perceptual_loss += MSE_Loss(synth_3, real_3)

return mse_loss, perceptual_loss


if __name__ == "__main__":
main()
84 changes: 84 additions & 0 deletions face_alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import scipy.ndimage
import os
import PIL.Image


def image_align(src_file, dst_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True):
# Align function from FFHQ dataset pre-processing step
# https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

lm = np.array(face_landmarks)
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise

# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2

# Load in-the-wild image.
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
return
img = PIL.Image.open(src_file)

# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink

# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]

# Pad.
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]

# Transform.
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

# Save aligned image.
img.save(dst_file, 'PNG')
Loading

0 comments on commit 80e15f3

Please sign in to comment.