-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsynthia_dataset.py
63 lines (50 loc) · 2.13 KB
/
synthia_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os.path as osp
import numpy as np
from torch.utils import data
from PIL import Image
import cv2
class SYNTHIA_DataSet(data.Dataset):
    """SYNTHIA image/label dataset with labels remapped to Cityscapes train IDs.

    Each non-blank line of ``list_path`` names one sample. The RGB image is
    read from ``<root>/RGB/<name>`` and its label from
    ``<root>/GT/LABELS/<name>``.

    Args:
        root: Dataset root directory.
        list_path: Text file with one sample file name per line.
        max_iters: If given, the sample list is repeated
            ``ceil(max_iters / len(list))`` times so that at least
            ``max_iters`` items are available (no-op for an empty list).
        crop_size: Output size as ``(width, height)``; image and label are
            resized (not cropped) to it.
        mean: Per-channel value subtracted from the image *after* it has been
            converted to BGR order, i.e. given in (B, G, R) order.
        scale: Stored as ``self.scale``; not used by this class.
        mirror: Stored as ``self.is_mirror``; not used by this class.
        ignore_label: Value assigned to label pixels whose SYNTHIA class has
            no entry in ``id_to_trainid``.
    """

    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321),
                 mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.scale = scale
        self.ignore_label = ignore_label
        self.mean = mean
        self.is_mirror = mirror
        # Context manager so the list file is closed deterministically; blank
        # lines are skipped because they would yield samples with no file name.
        with open(list_path) as f:
            self.img_ids = [line.strip() for line in f if line.strip()]
        # Repeat the list so one pass covers max_iters samples. Guard against an
        # empty list, which would otherwise divide by zero.
        if max_iters is not None and self.img_ids:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []
        # SYNTHIA class ID -> Cityscapes train ID (0..18).
        self.id_to_trainid = {3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5,
                              15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12,
                              8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18}
        for name in self.img_ids:
            img_file = osp.join(self.root, "RGB/%s" % name)
            label_file = osp.join(self.root, "GT/LABELS/%s" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })

    def __len__(self):
        """Return the number of (possibly repeated) samples."""
        return len(self.files)

    def __getitem__(self, index):
        """Load, resize and normalise one sample.

        Returns:
            ``(image, label, size, name)`` where ``image`` is a float32
            channel-first (C, H, W) BGR array with ``mean`` subtracted,
            ``label`` is a float32 (H, W) array of Cityscapes train IDs
            (``ignore_label`` for unmapped classes), ``size`` is the
            (H, W, C) shape of the resized image before transposition, and
            ``name`` is the sample's file name.

        Raises:
            FileNotFoundError: If the label file cannot be read.
        """
        datafiles = self.files[index]
        # convert() loads the pixel data, so the file can be closed right after.
        with Image.open(datafiles["img"]) as img:
            image = img.convert('RGB')
        # -1 == cv2.IMREAD_UNCHANGED: keep the label's bit depth and channels.
        label = cv2.imread(datafiles["label"], -1)
        if label is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("Could not read label file: %s" % datafiles["label"])
        name = datafiles["name"]

        # resize; crop_size is (width, height) for both PIL and OpenCV.
        image = image.resize(self.crop_size, Image.BICUBIC)
        # Nearest-neighbour: the default (bilinear) would blend neighbouring
        # class IDs into wrong or non-existent classes along object edges.
        label = cv2.resize(label, self.crop_size, interpolation=cv2.INTER_NEAREST)

        image = np.asarray(image, np.float32)
        # Class ID comes from channel index 2 of the OpenCV-loaded label
        # (OpenCV stores colour channels in BGR order).
        label = np.asarray(label[:, :, 2])

        # re-assign labels to match the format of Cityscapes; anything not in
        # the mapping becomes ignore_label.
        label_copy = np.full(label.shape, self.ignore_label, dtype=np.float32)
        for k, v in self.id_to_trainid.items():
            label_copy[label == k] = v

        size = image.shape
        image = image[:, :, ::-1]  # change to BGR
        image -= self.mean
        image = image.transpose((2, 0, 1))
        return image.copy(), label_copy, np.array(size), name