-
Notifications
You must be signed in to change notification settings - Fork 283
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a6728dc
commit 205f3bf
Showing
5 changed files
with
144 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,118 +1,110 @@ | ||
import os | ||
import sys | ||
from collections import OrderedDict | ||
from os import walk | ||
#!/usr/bin/env python | ||
# coding: utf-8 | ||
# | ||
# Author: Kazuto Nakashima | ||
# URL: http://kazuto1011.github.io | ||
# Created: 2017-11-03 | ||
|
||
|
||
import argparse | ||
import os.path as osp | ||
|
||
import cv2 | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import scipy | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import yaml | ||
from tensorboardX import SummaryWriter | ||
from torch.autograd import Variable | ||
|
||
from models import DeepLab | ||
import torchvision.models as models | ||
from docopt import docopt | ||
|
||
docstr = """Evaluate ResNet-DeepLab trained on scenes (VOC 2012),a total of 21 labels including background | ||
Usage: | ||
evalpyt.py [options] | ||
Options: | ||
-h, --help Print this message | ||
--visualize view outputs of each sketch | ||
--snapPrefix=<str> Snapshot [default: VOC12_scenes_] | ||
--testGTpath=<str> Ground truth path prefix [default: data/gt/] | ||
--testIMpath=<str> Sketch images path prefix [default: data/img/] | ||
--NoLabels=<int> The number of different labels in training data, VOC has 21 labels, including background [default: 21] | ||
--gpu0=<int> GPU number [default: 0] | ||
""" | ||
|
||
args = docopt(docstr, version='v0.1') | ||
print args | ||
|
||
|
||
def get_iou(pred, gt): | ||
if pred.shape != gt.shape: | ||
print 'pred shape', pred.shape, 'gt shape', gt.shape | ||
assert(pred.shape == gt.shape) | ||
gt = gt.astype(np.float32) | ||
pred = pred.astype(np.float32) | ||
|
||
max_label = int(args['--NoLabels']) - 1 # labels from 0,1, ... 20(for VOC) | ||
count = np.zeros((max_label + 1,)) | ||
for j in range(max_label + 1): | ||
x = np.where(pred == j) | ||
p_idx_j = set(zip(x[0].tolist(), x[1].tolist())) | ||
x = np.where(gt == j) | ||
GT_idx_j = set(zip(x[0].tolist(), x[1].tolist())) | ||
# pdb.set_trace() | ||
n_jj = set.intersection(p_idx_j, GT_idx_j) | ||
u_jj = set.union(p_idx_j, GT_idx_j) | ||
|
||
if len(GT_idx_j) != 0: | ||
count[j] = float(len(n_jj)) / float(len(u_jj)) | ||
|
||
result_class = count | ||
Aiou = np.sum(result_class[:]) / float(len(np.unique(gt))) | ||
|
||
return Aiou | ||
|
||
|
||
gpu0 = int(args['--gpu0']) | ||
im_path = args['--testIMpath'] | ||
model = DeepLab(int(args['--NoLabels'])) | ||
model.eval() | ||
counter = 0 | ||
model.cuda(gpu0) | ||
snapPrefix = args['--snapPrefix'] | ||
gt_path = args['--testGTpath'] | ||
img_list = open('data/list/val.txt').readlines() | ||
|
||
# TODO set the (different iteration)models that you want to evaluate on. Models are saved during training after every 1000 iters by default. | ||
for iter in range(1, 21): | ||
saved_state_dict = torch.load(os.path.join( | ||
'data/snapshots/', snapPrefix + str(iter) + '000.pth')) | ||
if counter == 0: | ||
print snapPrefix | ||
counter += 1 | ||
model.load_state_dict(saved_state_dict) | ||
|
||
pytorch_list = [] | ||
for i in img_list: | ||
img = np.zeros((513, 513, 3)) | ||
|
||
img_temp = cv2.imread(os.path.join( | ||
im_path, i[:-1] + '.jpg')).astype(float) | ||
img_original = img_temp | ||
img_temp[:, :, 0] = img_temp[:, :, 0] - 104.008 | ||
img_temp[:, :, 1] = img_temp[:, :, 1] - 116.669 | ||
img_temp[:, :, 2] = img_temp[:, :, 2] - 122.675 | ||
img[:img_temp.shape[0], :img_temp.shape[1], :] = img_temp | ||
gt = cv2.imread(os.path.join(gt_path, i[:-1] + '.png'), 0) | ||
gt[gt == 255] = 0 | ||
|
||
output = model(Variable(torch.from_numpy(img[np.newaxis, :].transpose( | ||
0, 3, 1, 2)).float(), volatile=True).cuda(gpu0)) | ||
interp = nn.UpsamplingBilinear2d(size=(513, 513)) | ||
output = interp(output[3]).cpu().data[0].numpy() | ||
output = output[:, :img_temp.shape[0], :img_temp.shape[1]] | ||
|
||
output = output.transpose(1, 2, 0) | ||
from torchnet.meter import MovingAverageValueMeter | ||
from tqdm import tqdm | ||
|
||
from libs.datasets import get_dataset | ||
from libs.models import DeepLab | ||
from libs.utils import CrossEntropyLoss2d, scores | ||
|
||
|
||
def main(args): | ||
# Configuration | ||
with open(args.config) as f: | ||
config = yaml.load(f) | ||
|
||
image_size = (config['image']['size']['test'], | ||
config['image']['size']['test']) | ||
n_classes = config['dataset'][args.dataset]['n_classes'] | ||
|
||
# Dataset | ||
dataset = get_dataset(args.dataset)( | ||
root=config['dataset'][args.dataset]['root'], | ||
split='test', | ||
image_size=image_size, | ||
scale=False, | ||
flip=False, | ||
preload=False | ||
) | ||
|
||
# DataLoader | ||
loader = torch.utils.data.DataLoader( | ||
dataset=dataset, | ||
batch_size=args.batch_size, | ||
num_workers=config['num_workers'], | ||
shuffle=False | ||
) | ||
loader_iter = iter(loader) | ||
|
||
checkpoint = torch.load(args.checkpoint, | ||
map_location=lambda storage, | ||
loc: storage) | ||
state_dict = checkpoint['weight'] | ||
print('Result after {} iterations'.format(checkpoint['iteration'])) | ||
|
||
# Model | ||
model = DeepLab(n_classes=n_classes) | ||
model.load_state_dict(state_dict) | ||
model.eval() | ||
if args.cuda: | ||
model.cuda() | ||
|
||
targets, outputs = [], [] | ||
for data, target in tqdm(loader, total=len(loader), | ||
leave=False, dynamic_ncols=True): | ||
# Image | ||
data = data.cuda() if args.cuda else data | ||
data = Variable(data, volatile=True) | ||
|
||
# Forward propagation | ||
output = model(data) | ||
output = F.upsample(output[3], size=image_size, mode='bilinear') | ||
output = output[0].cpu().data.numpy().transpose(1, 2, 0) | ||
output = np.argmax(output, axis=2) | ||
if args['--visualize']: | ||
plt.subplot(3, 1, 1) | ||
plt.imshow(img_original) | ||
plt.subplot(3, 1, 2) | ||
plt.imshow(gt) | ||
plt.subplot(3, 1, 3) | ||
plt.imshow(output) | ||
plt.show() | ||
|
||
iou_pytorch = get_iou(output, gt) | ||
pytorch_list.append(iou_pytorch) | ||
|
||
print 'pytorch', iter, np.sum(np.asarray(pytorch_list)) / len(pytorch_list) | ||
target = target.numpy() | ||
|
||
for o, t in zip(output, target): | ||
outputs.append(o) | ||
targets.append(t) | ||
|
||
score, class_iou = scores(targets, outputs, n_class=n_classes) | ||
|
||
for k, v in score.items(): | ||
print k, v | ||
|
||
for i in range(n_classes): | ||
print i, class_iou[i] | ||
|
||
|
||
if __name__ == '__main__': | ||
# Parsing arguments | ||
parser = argparse.ArgumentParser(description='') | ||
parser.add_argument('--no_cuda', action='store_true', default=False) | ||
parser.add_argument('--dataset', nargs='?', type=str, default='cocostuff') | ||
parser.add_argument('--config', type=str, default='config/default.yaml') | ||
parser.add_argument('--checkpoint', type=str, default=None) | ||
parser.add_argument('--batch_size', type=int, default=1) | ||
|
||
args = parser.parse_args() | ||
args.cuda = not args.no_cuda and torch.cuda.is_available() | ||
|
||
for arg in vars(args): | ||
print('{0:20s}: {1}'.format(arg.rjust(20), getattr(args, arg))) | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from libs.utils.loss import CrossEntropyLoss2d | ||
from libs.utils.lr_scheduler import poly_lr_scheduler | ||
from libs.utils.metric import scores |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Originally written by wkentaro | ||
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py | ||
|
||
import numpy as np | ||
|
||
|
||
def _fast_hist(label_true, label_pred, n_class): | ||
mask = (label_true >= 0) & (label_true < n_class) | ||
hist = np.bincount( | ||
n_class * label_true[mask].astype(int) + | ||
label_pred[mask], minlength=n_class**2).reshape(n_class, n_class) | ||
return hist | ||
|
||
|
||
def scores(label_trues, label_preds, n_class): | ||
"""Returns accuracy score evaluation result. | ||
- overall accuracy | ||
- mean accuracy | ||
- mean IU | ||
- fwavacc | ||
""" | ||
hist = np.zeros((n_class, n_class)) | ||
for lt, lp in zip(label_trues, label_preds): | ||
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) | ||
acc = np.diag(hist).sum() / hist.sum() | ||
acc_cls = np.diag(hist) / hist.sum(axis=1) | ||
acc_cls = np.nanmean(acc_cls) | ||
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) | ||
mean_iu = np.nanmean(iu) | ||
freq = hist.sum(axis=1) / hist.sum() | ||
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() | ||
cls_iu = dict(zip(range(n_class), iu)) | ||
|
||
return {'Overall Acc: \t': acc, | ||
'Mean Acc : \t': acc_cls, | ||
'FreqW Acc : \t': fwavacc, | ||
'Mean IoU : \t': mean_iu, }, cls_iu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters