-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
274 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
#!/usr/bin/env python3 | ||
"""Process an image with the trained neural network | ||
Usage: | ||
demo.py [options] <yaml-config> <checkpoint> <image> | ||
demo.py (-h | --help ) | ||
Arguments: | ||
<yaml-config> Path to the yaml hyper-parameter file | ||
<checkpoint> Path to the checkpoint | ||
<image> Path to the directory containing processed images | ||
Options: | ||
-h --help Show this screen. | ||
-d --devices <devices> Comma seperated GPU devices [default: 0] | ||
""" | ||
|
||
import os | ||
import os.path as osp | ||
import pprint | ||
import random | ||
|
||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import skimage.io | ||
import skimage.transform | ||
import torch | ||
import yaml | ||
from docopt import docopt | ||
|
||
import lcnn | ||
from lcnn.config import C, M | ||
from lcnn.models.line_vectorizer import LineVectorizer | ||
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner | ||
from lcnn.postprocess import postprocess | ||
from lcnn.utils import recursive_to | ||
|
||
PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5} | ||
cmap = plt.get_cmap("jet") | ||
norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0) | ||
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) | ||
sm.set_array([]) | ||
|
||
|
||
def c(x): | ||
return sm.to_rgba(x) | ||
|
||
|
||
def main(): | ||
args = docopt(__doc__) | ||
config_file = args["<yaml-config>"] or "config/wireframe.yaml" | ||
C.update(C.from_yaml(filename=config_file)) | ||
M.update(C.model) | ||
pprint.pprint(C, indent=4) | ||
|
||
random.seed(0) | ||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
|
||
device_name = "cpu" | ||
os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"] | ||
if torch.cuda.is_available(): | ||
device_name = "cuda" | ||
torch.backends.cudnn.deterministic = True | ||
torch.cuda.manual_seed(0) | ||
print("Let's use", torch.cuda.device_count(), "GPU(s)!") | ||
else: | ||
print("CUDA is not available") | ||
device = torch.device(device_name) | ||
checkpoint = torch.load(args["<checkpoint>"], map_location=device) | ||
|
||
# Load model | ||
model = lcnn.models.hg( | ||
depth=M.depth, | ||
head=lambda c_in, c_out: MultitaskHead(c_in, c_out), | ||
num_stacks=M.num_stacks, | ||
num_blocks=M.num_blocks, | ||
num_classes=sum(sum(M.head_size, [])), | ||
) | ||
model = MultitaskLearner(model) | ||
model = LineVectorizer(model) | ||
model.load_state_dict(checkpoint["model_state_dict"]) | ||
model = model.to(device) | ||
model.eval() | ||
|
||
im = skimage.io.imread(args["<image>"])[:, :, :3] | ||
im_resized = skimage.transform.resize(im, (512, 512)) * 255 | ||
image = (im_resized - M.image.mean) / M.image.stddev | ||
image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float() | ||
with torch.no_grad(): | ||
input_dict = { | ||
"image": image.to(device), | ||
"meta": [ | ||
{ | ||
"junc": torch.zeros(1, 2).to(device), | ||
"jtyp": torch.zeros(1, dtype=torch.uint8).to(device), | ||
"Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device), | ||
"Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device), | ||
} | ||
], | ||
"target": { | ||
"jmap": torch.zeros([1, 1, 128, 128]).to(device), | ||
"joff": torch.zeros([1, 1, 2, 128, 128]).to(device), | ||
}, | ||
"do_evaluation": True, | ||
} | ||
H = model(input_dict)["preds"] | ||
|
||
lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] | ||
scores = H["score"][0].cpu().numpy() | ||
for i in range(1, len(lines)): | ||
if (lines[i] == lines[0]).all(): | ||
lines = lines[:i] | ||
scores = scores[:i] | ||
break | ||
|
||
# postprocess lines to remove overlapped lines | ||
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 | ||
nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False) | ||
|
||
plt.gca().set_axis_off() | ||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) | ||
plt.margins(0, 0) | ||
plt.gca().xaxis.set_major_locator(plt.NullLocator()) | ||
plt.gca().yaxis.set_major_locator(plt.NullLocator()) | ||
for i, t in enumerate([0.95, 0.96, 0.97, 0.98, 0.99]): | ||
for (a, b), s in zip(nlines, nscores): | ||
if s < t: | ||
continue | ||
plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) | ||
plt.scatter(a[1], a[0], **PLTOPTS) | ||
plt.scatter(b[1], b[0], **PLTOPTS) | ||
plt.imshow(im) | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import numpy as np | ||
|
||
|
||
def pline(x1, y1, x2, y2, x, y): | ||
px = x2 - x1 | ||
py = y2 - y1 | ||
dd = px * px + py * py | ||
u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd)) | ||
dx = x1 + u * px - x | ||
dy = y1 + u * py - y | ||
return dx * dx + dy * dy | ||
|
||
|
||
def psegment(x1, y1, x2, y2, x, y): | ||
px = x2 - x1 | ||
py = y2 - y1 | ||
dd = px * px + py * py | ||
u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0) | ||
dx = x1 + u * px - x | ||
dy = y1 + u * py - y | ||
return dx * dx + dy * dy | ||
|
||
|
||
def plambda(x1, y1, x2, y2, x, y): | ||
px = x2 - x1 | ||
py = y2 - y1 | ||
dd = px * px + py * py | ||
return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd)) | ||
|
||
|
||
def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False): | ||
nlines, nscores = [], [] | ||
for (p, q), score in zip(lines, scores): | ||
start, end = 0, 1 | ||
for a, b in nlines: | ||
if ( | ||
min( | ||
max(pline(*p, *q, *a), pline(*p, *q, *b)), | ||
max(pline(*a, *b, *p), pline(*a, *b, *q)), | ||
) | ||
> threshold ** 2 | ||
): | ||
continue | ||
lambda_a = plambda(*p, *q, *a) | ||
lambda_b = plambda(*p, *q, *b) | ||
if lambda_a > lambda_b: | ||
lambda_a, lambda_b = lambda_b, lambda_a | ||
lambda_a -= tol | ||
lambda_b += tol | ||
|
||
# case 1: skip (if not do_clip) | ||
if start < lambda_a and lambda_b < end: | ||
continue | ||
|
||
# not intersect | ||
if lambda_b < start or lambda_a > end: | ||
continue | ||
|
||
# cover | ||
if lambda_a <= start and end <= lambda_b: | ||
start = 10 | ||
break | ||
|
||
# case 2 & 3: | ||
if lambda_a <= start and start <= lambda_b: | ||
start = lambda_b | ||
if lambda_a <= end and end <= lambda_b: | ||
end = lambda_a | ||
|
||
if start >= end: | ||
break | ||
|
||
if start >= end: | ||
continue | ||
nlines.append(np.array([p + (q - p) * start, p + (q - p) * end])) | ||
nscores.append(score) | ||
return np.array(nlines), np.array(nscores) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.