forked from jfzhang95/pytorch-deeplab-xception
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcombine_dbs.py
102 lines (82 loc) · 3.4 KB
/
combine_dbs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch.utils.data as data
class CombineDBs(data.Dataset):
    """Present several datasets as a single dataset, de-duplicated by image id.

    Every dataset passed in must expose an ``im_ids`` sequence and support
    integer indexing. An image id that appears in more than one dataset is
    served only from the first dataset that lists it, and every id listed by
    a dataset in ``excluded`` is dropped from the combination.

    Attributes:
        dataloaders: The datasets being combined, exactly as passed in.
        excluded: The datasets whose ids are removed, or ``None``.
        im_ids: Retained image ids, in first-seen order.
        cat_list: One ``{'db_ii': i, 'cat_ii': j}`` dict per retained id,
            where ``i`` indexes ``dataloaders`` and ``j`` is the position of
            the id inside that dataset's ``im_ids``.
        im_list: Always an empty list; kept only because the attribute has
            always been set by the constructor.
    """

    def __init__(self, dataloaders, excluded=None):
        """Build the combined index and print the resulting image count.

        Args:
            dataloaders: Datasets to merge. Each needs an ``im_ids``
                attribute and ``__getitem__``.
            excluded: Optional datasets whose ``im_ids`` must not appear in
                the combination. ``None`` or an empty sequence excludes
                nothing.

        Note:
            Image ids are tracked in sets, so they must be hashable.
        """
        super().__init__()
        self.dataloaders = dataloaders
        self.excluded = excluded

        # Collect every id that has to be filtered out, once, so the main
        # pass below can test membership in constant time.
        excluded_ids = set()
        if excluded:
            for dl in excluded:
                excluded_ids.update(dl.im_ids)

        self.cat_list = []
        self.im_list = []
        self.im_ids = []

        # Single pass over all datasets: the first dataset to list an id
        # "owns" it; later duplicates and excluded ids are skipped.
        seen = set()
        for ii, dl in enumerate(dataloaders):
            for jj, curr_im_id in enumerate(dl.im_ids):
                if curr_im_id in seen or curr_im_id in excluded_ids:
                    continue
                seen.add(curr_im_id)
                self.im_ids.append(curr_im_id)
                self.cat_list.append({'db_ii': ii, 'cat_ii': jj})

        print('Combined number of images: {:d}'.format(len(self.cat_list)))

    def __getitem__(self, index):
        """Return the sample at ``index`` from whichever dataset owns it.

        If the sample is a mapping containing a ``'meta'`` entry, the string
        form of the source dataset is stored under ``sample['meta']['db']``.
        """
        entry = self.cat_list[index]
        source_db = self.dataloaders[entry['db_ii']]
        sample = source_db[entry['cat_ii']]

        if 'meta' in sample:
            sample['meta']['db'] = str(source_db)

        return sample

    def __len__(self):
        """Return the number of retained (unique, non-excluded) images."""
        return len(self.cat_list)

    def __str__(self):
        """Return a two-line summary of the included and excluded datasets."""
        include_db = [str(db) for db in self.dataloaders]
        # ``excluded`` defaults to None; treat that as "nothing excluded"
        # instead of failing while iterating over None.
        exclude_db = [str(db) for db in (self.excluded or [])]
        return 'Included datasets:' + str(include_db) + '\n' + 'Excluded datasets:' + str(exclude_db)
if __name__ == "__main__":
    # Visual smoke test: combine the Pascal VOC train split with SBD
    # train+val, leave out everything in Pascal VOC val, and plot the
    # samples of the first two batches with their decoded label maps.
    import matplotlib.pyplot as plt
    from dataloaders import pascal
    from dataloaders import sbd
    import torch
    import numpy as np
    import dataloaders.custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torchvision import transforms

    # Augmenting pipeline for the datasets that are combined.
    train_pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomSized(512),
        tr.RandomRotate(15),
        tr.ToTensor()])

    # Fixed-size pipeline for the dataset that is excluded.
    test_pipeline = transforms.Compose([
        tr.FixedResize(size=(512, 512)),
        tr.ToTensor()])

    voc_val = pascal.VOCSegmentation(split='val', transform=test_pipeline)
    # Deliberately not named ``sbd`` so the imported module stays reachable.
    sbd_trainval = sbd.SBDSegmentation(split=['train', 'val'], transform=train_pipeline)
    voc_train = pascal.VOCSegmentation(split='train', transform=train_pipeline)

    combined = CombineDBs([voc_train, sbd_trainval], excluded=[voc_val])
    loader = torch.utils.data.DataLoader(combined, batch_size=2, shuffle=True, num_workers=0)

    for batch_idx, batch in enumerate(loader):
        images = batch['image'].numpy()
        labels = batch['label'].numpy()

        for k in range(batch["image"].size()[0]):
            # Drop the leading axis of the label before colourising it.
            label_map = np.squeeze(np.array(labels[k]).astype(np.uint8), axis=0)
            segmap = decode_segmap(label_map, dataset='pascal')
            # Move the first axis to the end for imshow
            # (presumably channels-first -> channels-last; confirm).
            picture = np.transpose(images[k], axes=[1, 2, 0]).astype(np.uint8)

            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(picture)
            plt.subplot(212)
            plt.imshow(segmap)

        # Two batches are enough for a visual check.
        if batch_idx == 1:
            break

    plt.show(block=True)