Merge remote-tracking branch 'origin/refresh'
davidsandberg committed Apr 1, 2018
2 parents fc3b4c6 + 2413277, commit a7590ff
Showing 10 changed files with 351 additions and 415 deletions.
src/facenet.py: 175 changes (113 additions, 62 deletions)
@@ -30,7 +30,6 @@
import os
from subprocess import Popen, PIPE
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
from scipy import misc
from sklearn.model_selection import KFold
@@ -39,6 +38,7 @@
import random
import re
from tensorflow.python.platform import gfile
import math
from six import iteritems

def triplet_loss(anchor, positive, negative, alpha):
@@ -61,19 +61,6 @@ def triplet_loss(anchor, positive, negative, alpha):

return loss

def decov_loss(xs):
"""Decov loss as described in https://arxiv.org/pdf/1511.06068.pdf
'Reducing Overfitting In Deep Networks by Decorrelating Representation'
"""
x = tf.reshape(xs, [int(xs.get_shape()[0]), -1])
m = tf.reduce_mean(x, 0, True)
z = tf.expand_dims(x-m, 2)
corr = tf.reduce_mean(tf.matmul(z, tf.transpose(z, perm=[0,2,1])), 0)
corr_frob_sqr = tf.reduce_sum(tf.square(corr))
corr_diag_sqr = tf.reduce_sum(tf.square(tf.diag_part(corr)))
loss = 0.5*(corr_frob_sqr - corr_diag_sqr)
return loss

def center_loss(features, label, alfa, nrof_classes):
"""Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
(http://ydwen.github.io/papers/WenECCV16.pdf)
@@ -85,7 +72,8 @@ def center_loss(features, label, alfa, nrof_classes):
centers_batch = tf.gather(centers, label)
diff = (1 - alfa) * (centers_batch - features)
centers = tf.scatter_sub(centers, label, diff)
loss = tf.reduce_mean(tf.square(features - centers_batch))
with tf.control_dependencies([centers]):
loss = tf.reduce_mean(tf.square(features - centers_batch))
return loss, centers
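The new `with tf.control_dependencies([centers])` block matters in TF 1.x graph mode: `tf.scatter_sub` returns an update op, and fetching `loss` alone would not otherwise force that op to run, so the dependency guarantees the center update executes on every loss evaluation. A minimal, self-contained sketch of the pattern (TF 1.x API; the variable and values are hypothetical):

```python
import tensorflow as tf  # TF 1.x graph-mode semantics

v = tf.Variable([1.0, 2.0])
update = tf.assign_sub(v, [0.1, 0.1])  # an update op; nothing runs it yet

# Without the dependency, fetching `loss` alone would skip `update`.
with tf.control_dependencies([update]):
    loss = tf.reduce_sum(tf.square(v))  # the read now happens after the update

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss))  # 0.9**2 + 1.9**2 = 4.42, i.e. the updated values
```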

def get_image_paths_and_labels(dataset):
@@ -118,38 +106,75 @@ def random_rotate_image(image):
angle = np.random.uniform(low=-10.0, high=10.0)
return misc.imrotate(image, angle, 'bicubic')

def read_and_augment_data(image_list, label_list, image_size, batch_size, max_nrof_epochs,
random_crop, random_flip, random_rotate, nrof_preprocess_threads, shuffle=True):

images = ops.convert_to_tensor(image_list, dtype=tf.string)
labels = ops.convert_to_tensor(label_list, dtype=tf.int32)

# Makes an input queue
input_queue = tf.train.slice_input_producer([images, labels],
num_epochs=max_nrof_epochs, shuffle=shuffle)

images_and_labels = []
# def read_and_augment_data(image_list, label_list, image_size, batch_size, max_nrof_epochs,
# random_crop, random_flip, random_rotate, nrof_preprocess_threads, shuffle=True):
#
# images = ops.convert_to_tensor(image_list, dtype=tf.string)
# labels = ops.convert_to_tensor(label_list, dtype=tf.int32)
#
# # Makes an input queue
# input_queue = tf.train.slice_input_producer([images, labels],
# num_epochs=max_nrof_epochs, shuffle=shuffle)
#
# images_and_labels = []
# for _ in range(nrof_preprocess_threads):
# image, label = read_images_from_disk(input_queue)
# if random_rotate:
# image = tf.py_func(random_rotate_image, [image], tf.uint8)
# if random_crop:
# image = tf.random_crop(image, [image_size, image_size, 3])
# else:
# image = tf.image.resize_image_with_crop_or_pad(image, image_size, image_size)
# if random_flip:
# image = tf.image.random_flip_left_right(image)
# #pylint: disable=no-member
# image.set_shape((image_size, image_size, 3))
# image = tf.image.per_image_standardization(image)
# images_and_labels.append([image, label])
#
# image_batch, label_batch = tf.train.batch_join(
# images_and_labels, batch_size=batch_size,
# capacity=4 * nrof_preprocess_threads * batch_size,
# allow_smaller_final_batch=True)
#
# return image_batch, label_batch

# 1: Random rotate 2: Random crop 4: Random flip 8: Fixed image standardization 16: Flip
RANDOM_ROTATE = 1
RANDOM_CROP = 2
RANDOM_FLIP = 4
FIXED_STANDARDIZATION = 8
FLIP = 16
def create_input_pipeline(images_and_labels_list, input_queue, image_size, nrof_preprocess_threads):
for _ in range(nrof_preprocess_threads):
image, label = read_images_from_disk(input_queue)
if random_rotate:
image = tf.py_func(random_rotate_image, [image], tf.uint8)
if random_crop:
image = tf.random_crop(image, [image_size, image_size, 3])
else:
image = tf.image.resize_image_with_crop_or_pad(image, image_size, image_size)
if random_flip:
image = tf.image.random_flip_left_right(image)
#pylint: disable=no-member
image.set_shape((image_size, image_size, 3))
image = tf.image.per_image_standardization(image)
images_and_labels.append([image, label])
filenames, label, control = input_queue.dequeue()
images = []
for filename in tf.unstack(filenames):
file_contents = tf.read_file(filename)
image = tf.image.decode_image(file_contents, 3)
image = tf.cond(get_control_flag(control[0], RANDOM_ROTATE),
lambda:tf.py_func(random_rotate_image, [image], tf.uint8),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], RANDOM_CROP),
lambda:tf.random_crop(image, image_size + (3,)),
lambda:tf.image.resize_image_with_crop_or_pad(image, image_size[0], image_size[1]))
image = tf.cond(get_control_flag(control[0], RANDOM_FLIP),
lambda:tf.image.random_flip_left_right(image),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION),
lambda:(tf.cast(image, tf.float32) - 127.5)/128.0,
lambda:tf.image.per_image_standardization(image))
image = tf.cond(get_control_flag(control[0], FLIP),
lambda:tf.image.flip_left_right(image),
lambda:tf.identity(image))
#pylint: disable=no-member
image.set_shape(image_size + (3,))
images.append(image)
images_and_labels_list.append([images, label])
return images_and_labels_list

image_batch, label_batch = tf.train.batch_join(
images_and_labels, batch_size=batch_size,
capacity=4 * nrof_preprocess_threads * batch_size,
allow_smaller_final_batch=True)

return image_batch, label_batch
def get_control_flag(control, field):
return tf.equal(tf.mod(tf.floor_div(control, field), 2), 1)
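Each element dequeued in `create_input_pipeline` carries an integer `control` whose bits select the preprocessing steps listed above; `get_control_flag` tests a single bit with floor-division and mod, the arithmetic equivalent of `(control & field) != 0`. A pure-Python mirror for illustration (the `_py` helper is hypothetical):

```python
# Pure-Python mirror of get_control_flag, for illustration only.
RANDOM_ROTATE, RANDOM_CROP, RANDOM_FLIP = 1, 2, 4
FIXED_STANDARDIZATION, FLIP = 8, 16

def get_control_flag_py(control, field):
    # Shift down to the bit for `field`, then test its parity.
    return (control // field) % 2 == 1

control = RANDOM_ROTATE + RANDOM_FLIP      # rotate and flip, nothing else
assert get_control_flag_py(control, RANDOM_ROTATE)
assert not get_control_flag_py(control, RANDOM_CROP)
assert get_control_flag_py(control, RANDOM_FLIP)
```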

def _add_loss_summaries(total_loss):
"""Add summaries for losses.
@@ -305,7 +330,10 @@ def get_learning_rate_from_file(filename, epoch):
if line:
par = line.strip().split(':')
e = int(par[0])
lr = float(par[1])
if par[1]=='-':
lr = -1
else:
lr = float(par[1])
if e <= epoch:
learning_rate = lr
else:
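The loop parses `epoch: rate` pairs, and `-` as a rate is now mapped to -1, presumably a sentinel interpreted by the caller (its handling is outside this hunk). A hypothetical schedule file in the format this parser expects:

```
0: 0.1
65: 0.01
77: 0.001
1000: -
```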
@@ -345,26 +373,27 @@ def get_image_paths(facedir):
image_paths = [os.path.join(facedir,img) for img in images]
return image_paths

def split_dataset(dataset, split_ratio, mode):
def split_dataset(dataset, split_ratio, min_nrof_images_per_class, mode):
if mode=='SPLIT_CLASSES':
nrof_classes = len(dataset)
class_indices = np.arange(nrof_classes)
np.random.shuffle(class_indices)
split = int(round(nrof_classes*split_ratio))
split = int(round(nrof_classes*(1-split_ratio)))
train_set = [dataset[i] for i in class_indices[0:split]]
test_set = [dataset[i] for i in class_indices[split:-1]]
elif mode=='SPLIT_IMAGES':
train_set = []
test_set = []
min_nrof_images = 2
for cls in dataset:
paths = cls.image_paths
np.random.shuffle(paths)
split = int(round(len(paths)*split_ratio))
if split<min_nrof_images:
continue # Not enough images for test set. Skip class...
train_set.append(ImageClass(cls.name, paths[0:split]))
test_set.append(ImageClass(cls.name, paths[split:-1]))
nrof_images_in_class = len(paths)
split = int(math.floor(nrof_images_in_class*(1-split_ratio)))
if split==nrof_images_in_class:
split = nrof_images_in_class-1
if split>=min_nrof_images_per_class and nrof_images_in_class-split>=1:
train_set.append(ImageClass(cls.name, paths[:split]))
test_set.append(ImageClass(cls.name, paths[split:]))
else:
raise ValueError('Invalid train/test split mode "%s"' % mode)
return train_set, test_set
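In SPLIT_IMAGES mode the per-class split index is now floor(n*(1-split_ratio)), clamped so at least one image is left for the test side, and a class survives only if its training side keeps at least `min_nrof_images_per_class` images. (In SPLIT_CLASSES mode, note that `class_indices[split:-1]` silently drops the last shuffled class; `[split:]` would keep it.) A small sketch mirroring the arithmetic above:

```python
import math

def split_index(n, split_ratio):
    # Mirrors the SPLIT_IMAGES branch above (illustration only).
    split = int(math.floor(n * (1 - split_ratio)))
    if split == n:       # keep at least one image for the test side
        split = n - 1
    return split

print(split_index(10, 0.2))  # 8 -> 8 train / 2 test
print(split_index(5, 0.1))   # 4 -> 4 train / 1 test (floor(4.5) = 4)
print(split_index(1, 0.2))   # 0 -> class dropped once min_nrof_images_per_class >= 1
```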
@@ -412,8 +441,24 @@ def get_model_filenames(model_dir):
max_step = step
ckpt_file = step_str.groups()[0]
return meta_file, ckpt_file

def distance(embeddings1, embeddings2, distance_metric=0):
if distance_metric==0:
# Squared Euclidean distance
diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
elif distance_metric==1:
# Distance based on cosine similarity
dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)
norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)
similarity = dot / norm
dist = np.arccos(similarity) / math.pi
else:
raise ValueError('Undefined distance metric %d' % distance_metric)

return dist
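Metric 0 is the squared Euclidean distance; metric 1 turns cosine similarity into an angular distance via arccos, scaled by 1/pi so it lies in [0, 1]. A quick sanity check on toy 2-D embeddings, using the `distance` function above:

```python
import numpy as np

a = np.array([[1.0, 0.0], [1.0, 0.0]])
b = np.array([[1.0, 0.0], [0.0, 1.0]])

print(distance(a, b, distance_metric=0))  # [0. 2.]  identical pair, then (1)^2 + (-1)^2
print(distance(a, b, distance_metric=1))  # [0. 0.5] angles of 0 and 90 degrees, over pi
```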

def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10):
def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
@@ -424,11 +469,14 @@ def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_fold
fprs = np.zeros((nrof_folds,nrof_thresholds))
accuracy = np.zeros((nrof_folds))

diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
indices = np.arange(nrof_pairs)

for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
if subtract_mean:
mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)
else:
mean = 0.0
dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)

# Find the best threshold for the fold
acc_train = np.zeros((nrof_thresholds))
@@ -439,8 +487,8 @@ def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_fold
tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])
_, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])

tpr = np.mean(tprs,0)
fpr = np.mean(fprs,0)
tpr = np.mean(tprs,0)
fpr = np.mean(fprs,0)
return tpr, fpr, accuracy
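With `subtract_mean=True` the embeddings are centered using a mean computed from the training folds only, so no test-fold statistics leak into the evaluation; with it off, `mean` is the scalar 0.0 and the subtraction is a no-op. A sketch of one fold under this scheme (random data and hypothetical fold indices):

```python
import numpy as np

emb1, emb2 = np.random.rand(6, 4), np.random.rand(6, 4)
train_set, test_set = np.arange(4), np.arange(4, 6)  # hypothetical fold split

# The mean comes from the training pairs only, but is applied everywhere.
mean = np.mean(np.concatenate([emb1[train_set], emb2[train_set]]), axis=0)
dist = distance(emb1 - mean, emb2 - mean, distance_metric=0)
print(dist[test_set])  # only the held-out pairs get scored
```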

def calculate_accuracy(threshold, dist, actual_issame):
@@ -457,7 +505,7 @@ def calculate_accuracy(threshold, dist, actual_issame):



def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
@@ -467,11 +515,14 @@ def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_targe
val = np.zeros(nrof_folds)
far = np.zeros(nrof_folds)

diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
indices = np.arange(nrof_pairs)

for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
if subtract_mean:
mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)
else:
mean = 0.0
dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)

# Find the threshold that gives FAR = far_target
far_train = np.zeros(nrof_thresholds)