-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcityscapes_dataset.py
45 lines (38 loc) · 1.48 KB
/
cityscapes_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os.path as osp
import numpy as np
from torch.utils import data
from PIL import Image
class CityscapesDataSet(data.Dataset):
    """Cityscapes image dataset yielding resized, mean-subtracted BGR arrays.

    Each item is a tuple ``(image, size, name)`` where ``image`` is a
    float32 CHW array in BGR channel order with the per-channel mean
    subtracted, ``size`` is the HWC shape of the resized image, and
    ``name`` is the relative path as read from the list file.
    """

    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='val'):
        """
        Args:
            root: Dataset root directory containing ``leftImg8bit/``.
            list_path: Text file with one relative image path per line.
            max_iters: If given, the id list is repeated so it contains at
                least ``max_iters`` entries (supports epoch-less training).
            crop_size: (width, height) every image is resized to.
            mean: Per-channel mean subtracted from the BGR image.
            scale: Stored but not used by this loader.
            mirror: Stored as ``is_mirror`` but not used by this loader.
            ignore_label: Stored but not used by this loader.
            set: Split subdirectory under ``leftImg8bit`` ('train'/'val'/'test').
                (Parameter name shadows the builtin but is kept for
                backward compatibility with existing callers.)
        """
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.scale = scale
        self.ignore_label = ignore_label
        self.mean = mean
        self.is_mirror = mirror
        # Use a context manager so the list file is closed deterministically
        # (the original leaked the open file handle).
        with open(list_path) as list_file:
            self.img_ids = [i_id.strip() for i_id in list_file]
        if max_iters is not None:
            # Repeat the id list so it covers at least max_iters samples.
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []
        self.set = set
        for name in self.img_ids:
            img_file = osp.join(self.root, "leftImg8bit/%s/%s" % (self.set, name))
            self.files.append({
                "img": img_file,
                "name": name
            })

    def __len__(self):
        """Return the number of (possibly repeated) samples."""
        return len(self.files)

    def __getitem__(self, index):
        """Load, resize and normalize the image at ``index``."""
        datafiles = self.files[index]
        image = Image.open(datafiles["img"]).convert('RGB')
        name = datafiles["name"]
        # Resize to the fixed network input size (PIL expects (width, height)).
        image = image.resize(self.crop_size, Image.BICUBIC)
        image = np.asarray(image, np.float32)
        size = image.shape
        image = image[:, :, ::-1]  # RGB -> BGR (Caffe-style pretrained weights)
        image -= self.mean
        image = image.transpose((2, 0, 1))  # HWC -> CHW for the network
        return image.copy(), np.array(size), name