-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestModel.py
98 lines (85 loc) · 2.65 KB
/
TestModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import print_function
import sys
from PIL import Image
import os.path
from Parameters import Params
from torchvision import transforms
from torchvision import utils as tv
from Dataset import ImageSet
import torch
from torch.utils.data import DataLoader
from network_elements import Mapping
from torch.autograd import Variable
# Command line:
#   TestModel.py [-r] [-o OUTPUT_DIR] [-d DIRECTION] [-s SCALE] MODEL PATH
# MODEL and PATH are always the last two arguments; the flags may appear in
# any order before them.
#   -r  PATH is a directory of images (batch mode) instead of a single file
#   -o  output root directory (default: ./outputs/<MODEL>/)
#   -d  translation direction; "xy" selects G.data, anything else F.data
#       (default "yx")
#   -s  scale factor handed to the image loader (default 1.0)
args = sys.argv
path = args[-1]
model = args[-2]
# Option area only: flag values must never be taken from MODEL or PATH.
flags = args[1:-2]


def _flag_value(flag, default):
    """Return the argument that follows *flag* in the option area, or *default*.

    Exits with status 1 when *flag* is present but has no value in the option
    area. Without this check, MODEL would silently be used as the value.
    """
    if flag not in flags:
        return default
    value_index = flags.index(flag) + 1
    if value_index >= len(flags):
        print("Missing value for option " + flag, file=sys.stderr)
        sys.exit(1)
    return flags[value_index]


def _ensure_dir(directory):
    """Create *directory*, including missing parents, if it does not exist."""
    if not os.path.exists(directory):
        os.makedirs(directory)


p = Params()
p.model = model
test_transforms = []
if "-r" in flags:
    p.dir = 1

custom_output = _flag_value("-o", None)
if custom_output is not None:
    p.outputPath = custom_output
else:
    p.outputPath = "./outputs/" + model + "/"
_ensure_dir(p.outputPath)

if p.dir:
    # Batch mode: results go into a subfolder named after the input
    # directory. normpath strips any number of trailing separators first.
    # NOTE(review): the subfolder is added for both the default and the -o
    # output root -- confirm this is intended when -o is given.
    set_name = os.path.basename(os.path.normpath(path))
    p.outputPath = os.path.join(p.outputPath, set_name)
    _ensure_dir(p.outputPath)

# Always end with a separator: output file names are appended to this path,
# and a user-supplied -o value may lack the trailing slash.
p.outputPath = os.path.join(p.outputPath, "")

p.direction = _flag_value("-d", "yx")
p.scale = float(_flag_value("-s", 1.0))

print("Model under test is " + p.model)
print("The output path is " + p.outputPath)
# Dataset setup. ToTensor yields float tensors in [0, 1]. Normalize with
# mean 0.5 and std 0.5 per channel then maps them to [-1, 1].
images = ImageSet()
test_transforms += [transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

if p.dir:
    # Batch mode (-r): PATH must be a directory of images.
    if not os.path.isdir(path):
        # Stop here instead of handing a bad path to the loader.
        print("No directory exists at path: " + path, file=sys.stderr)
        sys.exit(1)
    print("Processing set in directory: " + path)
    images.loadImageSet(path, test_transforms, "test", p.scale)
else:
    # Single-image mode: PATH must be a regular file.
    if not os.path.isfile(path):
        print("No file exists at path: " + path, file=sys.stderr)
        sys.exit(1)
    print("Processing file at path: " + path)
    images.loadImage(path, test_transforms, "test", p.scale)

# One image per batch, in the order the ImageSet holds them.
imgLoader = DataLoader(images, 1, shuffle=False)
# Model setup. Direction "xy" loads the G weights; any other value
# loads the F weights. Both files live under ./model/<MODEL>/.
# NOTE: torch.load unpickles the file, so only load trusted weights.
weights_file = "/G.data" if p.direction == 'xy' else "/F.data"
modelfile = torch.load("./model/" + p.model + weights_file)

# The leading dimension of layer 19's weight is used as Mapping's third
# constructor argument. Presumably this is a feature/channel width; confirm
# against network_elements.Mapping.
im_size = modelfile['model.19.weight'].size(0)

F = Mapping(3, 3, im_size)
F.cuda()
F.load_state_dict(modelfile)
F.eval()  # inference mode for the whole module
# Inference: translate each image and save it as out_<name> in the output dir.
# NOTE(review): on PyTorch >= 0.4, wrap this loop in torch.no_grad() so no
# autograd graph is built. Variable/.data are kept for older versions.
for img in imgLoader:
    imname = img['path']
    print(imname)
    # The loader already yields a batch holding one image. Move it to the GPU
    # as float directly, rather than copying into a pre-allocated buffer with
    # a hard-coded (1, 3, H, W) shape.
    img_var = Variable(img['img'].cuda().float())
    # Invert the input Normalize(0.5, 0.5): x -> (x + 1) / 2.
    result = 0.5 * (F(img_var).data + 1.0)
    # imname is indexed as a one-element sequence (presumably from the default
    # batch collation). basename keeps any directory part of the stored name
    # out of the output file name.
    out_name = 'out_' + os.path.basename(imname[0])
    tv.save_image(result, os.path.join(p.outputPath, out_name))