-
Notifications
You must be signed in to change notification settings - Fork 95
/
datasets.py
95 lines (83 loc) · 3.57 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import glob
import json
import math
import os
import random
import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from lcnn.config import M
class WireframeDataset(Dataset):
    """Dataset of images paired with ``*_label.npz`` label archives.

    Label files are discovered with the glob ``{rootdir}/{split}/*_label.npz``
    and kept in sorted order.  The image that belongs to a label file is the
    ``.png`` in the same directory whose name is the label file name without
    the ``_label.npz`` suffix and without any ``_a0`` / ``_a1`` marker.
    """

    # Suffix carried by every label archive; removed to derive the image name.
    _LABEL_SUFFIX = "_label.npz"

    def __init__(self, rootdir, split):
        """Collect the label archives of one split.

        Args:
            rootdir: Directory that contains one sub-directory per split.
            split: Name of the sub-directory to read.
        """
        self.rootdir = rootdir
        filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
        # glob returns files in arbitrary order; sort for a stable indexing.
        filelist.sort()
        print(f"n{split}:", len(filelist))
        self.split = split
        self.filelist = filelist

    def __len__(self):
        """Return the number of label archives found for this split."""
        return len(self.filelist)

    @staticmethod
    def _image_path_and_flip(label_path):
        """Map a label archive path to ``(image_path, flipped)``.

        Only the file name is inspected.  Looking at the whole path would
        let a directory component such as ``data1/`` or ``set_a1/`` trigger
        the flip or get mangled by the marker removal.
        """
        basename = os.path.basename(label_path)
        # Keep the directory prefix byte-for-byte (separators included).
        prefix = label_path[: len(label_path) - len(basename)]
        stem = basename[: -len(WireframeDataset._LABEL_SUFFIX)]
        # NOTE(review): the flip test is a loose "a1" substring match while
        # the stripping below uses "_a0"/"_a1"; kept loose to preserve the
        # existing behaviour for file names -- confirm the naming scheme.
        flipped = "a1" in stem
        image_stem = stem.replace("_a0", "").replace("_a1", "")
        return prefix + image_stem + ".png", flipped

    def __getitem__(self, idx):
        """Load sample ``idx``.

        The line samples are drawn with ``np.random`` and their endpoint
        order is randomised with ``random``, so repeated calls with the same
        index generally return different ``meta`` contents.

        Returns:
            A tuple ``(image, meta, target)``:

            * ``image`` -- float tensor of the first three image channels,
              normalised with ``M.image.mean`` / ``M.image.stddev`` and with
              the channel axis moved to the front.
            * ``meta`` -- dict with the keys ``junc``, ``jtyp``, ``Lpos``,
              ``Lneg``, ``lpre``, ``lpre_label`` and ``lpre_feat``.
            * ``target`` -- dict of float tensors ``jmap``, ``joff`` and
              ``lmap`` taken from the label archive.
        """
        label_path = self.filelist[idx]
        iname, flipped = self._image_path_and_flip(label_path)
        image = io.imread(iname).astype(float)[:, :, :3]
        if flipped:
            # Mirror the image along its second axis.
            image = image[:, ::-1, :]
        image = (image - M.image.mean) / M.image.stddev
        image = np.rollaxis(image, 2).copy()

        # npz["jmap"]: [J, H, W]    Junction heat map
        # npz["joff"]: [J, 2, H, W] Junction offset within each pixel
        # npz["lmap"]: [H, W]       Line heat map with anti-aliasing
        # npz["junc"]: [Na, 3]      Junction coordinates
        # npz["Lpos"]: [M, 2]       Positive lines represented with junction indices
        # npz["Lneg"]: [M, 2]       Negative lines represented with junction indices
        # npz["lpos"]: [Np, 2, 3]   Positive lines represented with junction coordinates
        # npz["lneg"]: [Nn, 2, 3]   Negative lines represented with junction coordinates
        #
        # For junc, lpos, and lneg that stores the junction coordinates, the last
        # dimension is (y, x, t), where t represents the type of that junction.
        with np.load(label_path) as npz:
            target = {
                name: torch.from_numpy(npz[name]).float()
                for name in ["jmap", "joff", "lmap"]
            }

            # Shuffle, then keep at most the configured number of samples.
            lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
            lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
            npos, nneg = len(lpos), len(lneg)
            lpre = np.concatenate([lpos, lneg], 0)

            # Swap the two endpoints of each line with probability one half.
            for i in range(len(lpre)):
                if random.random() > 0.5:
                    lpre[i] = lpre[i, ::-1].copy()

            # Unit direction from the second endpoint to the first; the clip
            # avoids dividing by zero for degenerate (zero-length) lines.
            ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
            ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
            feat = [
                lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
                ldir * M.use_slop,
                lpre[:, :, 2],
            ]
            feat = np.concatenate(feat, 1)

            # Every npz[...] access reads the array from the archive again,
            # so fetch the junction table once and reuse it.
            junc = npz["junc"]
            n_junc = len(junc)
            meta = {
                "junc": torch.from_numpy(junc[:, :2]),
                "jtyp": torch.from_numpy(junc[:, 2]).byte(),
                "Lpos": self.adjacency_matrix(n_junc, npz["Lpos"]),
                "Lneg": self.adjacency_matrix(n_junc, npz["Lneg"]),
                "lpre": torch.from_numpy(lpre[:, :, :2]),
                "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
                "lpre_feat": torch.from_numpy(feat),
            }

        return torch.from_numpy(image).float(), meta, target

    def adjacency_matrix(self, n, link):
        """Build a symmetric ``(n + 1, n + 1)`` uint8 adjacency matrix.

        Args:
            n: Number of junctions.  The matrix gets one extra row and
                column (presumably a slot for a dummy junction -- confirm
                against the consumer of ``Lpos`` / ``Lneg``).
            link: NumPy array of index pairs; column 0 and column 1 are the
                two ends of each link.  May be empty.

        Returns:
            A ``torch.uint8`` tensor with ``1`` at ``[a, b]`` and ``[b, a]``
            for every pair ``(a, b)`` in ``link`` and ``0`` elsewhere.
        """
        mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
        link = torch.from_numpy(link)
        if len(link) > 0:
            # int64 index tensors are accepted by every PyTorch version,
            # whatever integer width the archive happened to store.
            link = link.long()
            mat[link[:, 0], link[:, 1]] = 1
            mat[link[:, 1], link[:, 0]] = 1
        return mat
def collate(batch):
    """Merge a list of ``(image, meta, target)`` samples into one batch.

    The first and third entries of every sample are stacked with
    ``default_collate``; the second entries are returned as a plain list,
    one item per sample, without being merged.
    """
    images = []
    metas = []
    targets = []
    for sample in batch:
        images.append(sample[0])
        metas.append(sample[1])
        targets.append(sample[2])

    stacked_images = default_collate(images)
    stacked_targets = default_collate(targets)
    return (stacked_images, metas, stacked_targets)