Skip to content

Commit

Permalink
resolved conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
visionjo committed Sep 11, 2019
2 parents 072c000 + d5d57c0 commit 9fb2780
Show file tree
Hide file tree
Showing 53 changed files with 774 additions and 558 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ By downloading the image data you agree to the following terms:


## Authors
* **Joseph Robinson** - [Github](https://github.com/huskyjo) - [web](http://www.jrobsvision.com)
* **Joseph Robinson** - [Github](https://github.com/visionjo) - [web](http://www.jrobsvision.com)

######
### Contact
Expand Down
Empty file removed data/.gitkeep
Empty file.
Empty file removed data/external/.gitkeep
Empty file.
Empty file removed data/interim/.gitkeep
Empty file.
Empty file removed data/processed/.gitkeep
Empty file.
Empty file removed data/raw/.gitkeep
Empty file.
120 changes: 0 additions & 120 deletions notebooks/bin_members.ipynb

This file was deleted.

277 changes: 0 additions & 277 deletions notebooks/explore_fiw.ipynb

This file was deleted.

Binary file removed references/.DS_Store
Binary file not shown.
Binary file removed references/ACM-MM-poster-2016.pdf
Binary file not shown.
Binary file removed references/fiw-ne-cv-workshop-2016.pdf
Binary file not shown.
Binary file removed references/fiw-poster-2017.pdf
Binary file not shown.
Binary file removed references/fiw_acm-mm-2016.pdf
Binary file not shown.
Binary file removed references/fiw_mdml_FG2017.pdf
Binary file not shown.
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
dlib
torchvision
matplotlib
torch
Pillow
scipy
scikit-learn
Expand Down
20 changes: 13 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
# -*- coding: utf-8 -*-

from setuptools import setup, find_packages
import setuptools


with open('README.md') as f:
with open('README.md', 'r') as f:
readme = f.read()

with open('LICENSE') as f:
with open('LICENSE', 'r') as f:
license = f.read()

setup(
name='FIW_KRT',
version='1.0.0',
name='fiwtools',
version='0.1.0',
description='Families In the WIld: A Kinship Recogntion Toolbox.',
long_description=readme,
author='Joseph Robinson',
author_email='robinson.jo@husky.neu.edu',
url='https://github.com/huskyjo/FIW_KRT',
url='https://github.com/visionjo/FIW_KRT',
packages=setuptools.find_packages(),
license=license,
packages=find_packages(exclude=('tests', 'docs'))
# packages=find_packages(exclude=('tests', 'docs'))
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
)

98 changes: 98 additions & 0 deletions sphereface_rfiw_baseline/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
<<<<<<< HEAD
from torch.utils import data
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
Expand All @@ -18,10 +19,38 @@ def __init__(self, root_dir, labels_path, n_classes, transform):
self.transform = transform
self.train_dataset = []
self.preprocess()
=======
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import csv


class FIW_DBBase(Dataset):
def __init__(self, root_dir, labels_path, n_classes=300, transform=None):
self.root_dir = root_dir
self.labels_path = labels_path
self.n_classes = n_classes
self.transform = transform
self.pairs = []

def __len__(self):
return len(self.pairs)

def __getitem__(self, idx):
pass



class FIW_Train(FIW_DBBase):
"""Dataset class for FIW Training Set"""
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19

def preprocess(self):
"""Process the labels file"""
lines = [line.rstrip() for line in open(self.labels_path, 'r')]
<<<<<<< HEAD

for l in lines:
spt = l.split()
Expand Down Expand Up @@ -65,10 +94,45 @@ def __getitem__(self, index):
path_a, path_b, label = self.val_dataset[index]
img_a = self.transform(Image.open(os.path.join(self.base_dir, path_a)))
img_b = self.transform(Image.open(os.path.join(self.base_dir, path_b)))
=======
for l in lines:
spt = l.split()
fname = spt[0]
label = int(spt[1])
# label_vec = [1 if i == label else 0 for i in range(self.n_classes)]
self.pairs.append([fname, label])

def __getitem__(self, index):
"""Return an image"""
filename, label = self.pairs[index]
impath = os.path.join(self.root_dir, filename)
image = Image.open(impath)
return self.transform(image), label



class FIW_Val(FIW_DBBase):
"""Dataset class for FIW Validation Set"""

def preprocess(self):
"""Process the pair CSVs"""
with open(self.labels_path, 'r') as f:
re = csv.reader(f)
lines = list(re)

self.pairs = [(l[2], l[3], bool(int(l[1]))) for l in lines]

def __getitem__(self, index):
"""Return a pair"""
path_a, path_b, label = self.pairs[index]
img_a = self.transform(Image.open(os.path.join(self.root_dir, path_a)))
img_b = self.transform(Image.open(os.path.join(self.root_dir, path_b)))
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19
return (img_a, img_b), label

def __len__(self):
"""Return the number of images."""
<<<<<<< HEAD
return len(self.val_dataset)

def get_train_loader(image_dir, labels_path='train/train.label', n_classes=300, image_size=(112, 96), batch_size=16, num_workers=1):
Expand All @@ -86,11 +150,32 @@ def get_train_loader(image_dir, labels_path='train/train.label', n_classes=300,
dataset = FIW_Train(image_dir, labels_path, n_classes, transform)

data_loader = data.DataLoader(dataset=dataset,
=======
return len(self.pairs)





def get_train_loader(image_dir, labels_path='train/train.label', n_classes=300, image_size=(112, 96), batch_size=16,
num_workers=1):
"""Build and return a data loader for the training set."""
transform = T.Compose([T.RandomHorizontalFlip(),
T.Resize(image_size),
T.ToTensor(),
T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset = FIW_Train(image_dir, labels_path, n_classes=n_classes, transform=transform)
dataset.preprocess()
data_loader = DataLoader(dataset=dataset,
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return data_loader

<<<<<<< HEAD
def get_val_loader(base_dir, csv_path, image_size=(112, 96), batch_size=128, num_workers=1):
"""Build and return a data loader for a split in the validation set."""
transform = []
Expand All @@ -103,6 +188,19 @@ def get_val_loader(base_dir, csv_path, image_size=(112, 96), batch_size=128, num
dataset = FIW_Val(base_dir, csv_path, transform)

data_loader = data.DataLoader(dataset=dataset,
=======

def get_val_loader(base_dir, csv_path, image_size=(112, 96), batch_size=128, num_workers=1):
"""Build and return a data loader for a split in the validation set."""
transform = T.Compose([T.Resize(image_size),
T.ToTensor(),
T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset = FIW_Val(base_dir, csv_path, transform=transform)
dataset.preprocess()
data_loader = DataLoader(dataset=dataset,
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)
Expand Down
114 changes: 114 additions & 0 deletions sphereface_rfiw_baseline/evaluate_rfiw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import argparse
import net_sphere
import os
from torchtools import cuda, TorchTools, Tensor
from data_loader import get_val_loader
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from torch.autograd import Variable
from fiwtools.utils.io import sys_home
import torch.nn.functional as F
do_plot = True


def initialize_roc_plot(ax, lw=2):
ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
# ax.xlabel('False Positive Rate')
# ax.ylabel('True Positive Rate')
# ax.title('Receiver operating characteristic example')
# ax.legend(loc="lower right")
return ax


def generate_roc(ax, tp_array, fp_array, roc_auc, color='darkred', lw=2, label='ROC curve (area = %0.2f)', init=False, fname=''):


# if init:
# initialize_roc_plot(lw)
ax.plot(fp_array, tp_array, color=color, lw=lw, label=label % roc_auc)
# if len(fname) > 0:
# plt.savefig(fname)
def validate(net, data_loader, ax):
print('Begin validation')
net.eval()
y_labels = []
distances = []
ii=0
for pairs, labels in iter(data_loader):
if ii > 10:
break
img_a = Variable(pairs[0]).type(Tensor)
img_b = Variable(pairs[1]).type(Tensor)

_, embs_a = net(img_a)
_, embs_b = net(img_b)

embs_a = embs_a.data
embs_b = embs_b.data
cos_dis = F.cosine_similarity(embs_a, embs_b)
distances += list(cos_dis.data.cpu().numpy())

y_labels += list(labels.numpy())
# ii += 1

dist_array = np.array(distances)
y_array = np.array(y_labels)

fpr, tpr, thresh = roc_curve(y_array, dist_array)
roc_auc = auc(fpr, tpr)

if do_plot:
initialize_roc_plot(ax)
generate_roc(ax, tpr, fpr, roc_auc, color='darkred')
# fpr_micro, tpr_micro, _ = roc_curve(y_array.ravel(), dist_array.ravel())
# roc_auc_micro = auc(fpr_micro, tpr_micro)
#
# generate_roc(tpr_micro, fpr_micro, roc_auc_micro, color='darkorange')

# plt.show()

return roc_auc

do_types = np.linspace(0, 6, 7).astype(np.uint8)
types = ['bb', 'ss', 'sibs', 'fd', 'fs', 'md', 'ms']

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='FIW Sphereface Baseline')
parser.add_argument('--type', '-t', default='bb', type=str, help='relationship type (None processes entire directory)')
parser.add_argument('--batch_size', default=32, type=int, help='training batch size')
parser.add_argument('--modelpath', default='finetuned/checkpoint.pth.tar', type=str,
help='the pretrained model to point to')
parser.add_argument('--label_dir', '-l', type=str, default=sys_home() + '/datasets/FIW/RFIW/val/pairs/',
help='Root directory of data (assumed to containing pairs list labels)')
parser.add_argument('--data_dir', '-d', type=str, default=sys_home() + '/datasets/FIW/RFIW/val/',
help='Root directory of data (assumed to contain valdata)')

args = parser.parse_args()

net = net_sphere.sphere20a(classnum=300)

if cuda:
net.cuda()

epoch, bess_acc = TorchTools.load_checkpoint(net, f_weights=args.modelpath)

ncols = int(np.ceil(len(do_types) / 2))
nrows = 2
f, axes = plt.subplots(nrows, ncols, sharex='all', sharey='all')

for i, id in enumerate(do_types):
if i < ncols:
ax = axes[0, i]
else:
ax = axes[1, i - ncols]
csv_file = os.path.join(args.label_dir, types[id] + '_val.csv')
loader = get_val_loader(args.data_dir, csv_file)
# f.subplot()
auc_score = validate(net, loader, ax)

print('{} pairs: {} (auc)'.format(types[id], auc_score))

plt.savefig('roc.png')
Loading

0 comments on commit 9fb2780

Please sign in to comment.