-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathLoadModel.py
90 lines (78 loc) · 3.4 KB
/
LoadModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
from torch import nn
import torch
from torchvision import models, transforms, datasets
import torch.nn.functional as F
import pretrainedmodels
from torchsummary import summary
from config import pretrained_model
import pdb
class MainModel(nn.Module):
    """Classifier built on a (possibly truncated) backbone with optional extra heads.

    The backbone's trailing child modules are removed for the architectures
    listed in ``_HEAD_CHILDREN_TO_DROP``. Its feature map is then average-pooled
    and fed to a bias-free linear classifier. Optional heads:

    * DCL (``config.use_dcl``): a second "swap" classifier plus a 1-channel
      mask branch computed from the feature map.
    * A-Softmax (``config.use_Asoftmax``): an ``AngleLinear`` classifier.

    Submodule attribute names (``model``, ``classifier``, ``classifier_swap``,
    ``Convmask``, ``Aclassifier``) determine the ``state_dict`` keys. Renaming
    them would break existing checkpoints.
    """

    # Number of trailing child modules stripped from each backbone so that
    # ``self.model`` returns a 4-D feature map (forward() pools it itself and
    # applies a Conv2d to it). Presumably these are the pooling / dropout /
    # classifier layers; verify when adding an architecture. Backbones not
    # listed here are kept whole.
    _HEAD_CHILDREN_TO_DROP = {
        'resnet50': 2,
        'se_resnet50': 2,
        'se_resnet101': 2,
        'se_resnext101_32x4d': 2,
        'senet154': 3,
    }

    def __init__(self, config):
        """Build the backbone and classification heads.

        Args:
            config: object exposing ``use_dcl``, ``numcls``, ``backbone`` and
                ``use_Asoftmax``. When ``use_dcl`` is true it must also expose
                ``cls_2`` and ``cls_2xmul``.

        Raises:
            ValueError: if ``use_dcl`` is set but neither ``cls_2`` nor
                ``cls_2xmul`` is, because no swap classifier could be built.
        """
        super().__init__()
        self.use_dcl = config.use_dcl
        self.num_classes = config.numcls
        self.backbone_arch = config.backbone
        self.use_Asoftmax = config.use_Asoftmax

        # Fail before any (possibly slow) weight loading. Without this check the
        # misconfiguration surfaced only later, as an AttributeError on
        # self.classifier_swap inside forward().
        if self.use_dcl and not (config.cls_2 or config.cls_2xmul):
            raise ValueError(
                'use_dcl requires config.cls_2 or config.cls_2xmul to be set'
            )
        print(self.backbone_arch)

        self.model = self._build_backbone(self.backbone_arch)
        n_drop = self._HEAD_CHILDREN_TO_DROP.get(self.backbone_arch)
        if n_drop is not None:
            self.model = nn.Sequential(*list(self.model.children())[:-n_drop])

        # NOTE(review): the feature width is hard-coded. This assumes every
        # supported truncated backbone emits 2048 channels; confirm before
        # adding architectures.
        feat_dim = 2048
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(feat_dim, self.num_classes, bias=False)

        if self.use_dcl:
            # cls_2xmul takes precedence when both flags are set, matching the
            # original assignment order (the later assignment won).
            if config.cls_2xmul:
                self.classifier_swap = nn.Linear(
                    feat_dim, 2 * self.num_classes, bias=False
                )
            else:
                self.classifier_swap = nn.Linear(feat_dim, 2, bias=False)
            self.Convmask = nn.Conv2d(
                feat_dim, 1, 1, stride=1, padding=0, bias=True
            )
            self.avgpool2 = nn.AvgPool2d(2, stride=2)

        if self.use_Asoftmax:
            # NOTE(review): AngleLinear does not appear in this file's visible
            # imports. Unless it is provided elsewhere, this raises NameError.
            self.Aclassifier = AngleLinear(
                feat_dim, self.num_classes, bias=False
            )

    @staticmethod
    def _build_backbone(arch):
        """Instantiate the full (untruncated) backbone named *arch*.

        torchvision architectures are tried first. Otherwise the name is looked
        up in ``pretrainedmodels``.
        """
        if arch in dir(models):
            backbone = getattr(models, arch)()
            if arch in pretrained_model:
                # map_location='cpu' lets GPU-saved checkpoints load on
                # CPU-only hosts. The module is still on CPU at this point, so
                # the loaded weights are identical either way.
                state = torch.load(pretrained_model[arch], map_location='cpu')
                backbone.load_state_dict(state)
            return backbone
        if arch in pretrained_model:
            # NOTE(review): pretrained=None builds the model without weights,
            # and the checkpoint at pretrained_model[arch] is never loaded here.
            # Confirm whether a caller loads it later. Otherwise this trains
            # from scratch despite a configured checkpoint.
            return pretrainedmodels.__dict__[arch](
                num_classes=1000, pretrained=None
            )
        # Uses the library's default ``pretrained`` setting.
        return pretrainedmodels.__dict__[arch](num_classes=1000)

    def _pool(self, feat):
        """Apply ``self.avgpool`` and flatten everything but the batch dim."""
        pooled = self.avgpool(feat)
        return pooled.view(pooled.size(0), -1)

    def forward(self, x, last_cont=None):
        """Run the backbone and every enabled head.

        Args:
            x: input batch for the backbone.
            last_cont: optional second input batch. It is used only by the
                A-Softmax head, which then classifies the pooled backbone
                features of ``last_cont`` instead of every other sample of ``x``.

        Returns:
            list: ``[logits]``, then ``swap_logits, mask`` when DCL is enabled,
            then the A-Softmax output when that head is enabled.
        """
        feat = self.model(x)

        if self.use_dcl:
            # 1x1 conv -> 2x2 average pool -> tanh, flattened per sample.
            mask = torch.tanh(self.avgpool2(self.Convmask(feat)))
            mask = mask.view(mask.size(0), -1)

        pooled = self._pool(feat)
        out = [self.classifier(pooled)]
        if self.use_dcl:
            out.append(self.classifier_swap(pooled))
            out.append(mask)
        if self.use_Asoftmax:
            if last_cont is None:
                # Every other sample of the batch (indices 0, 2, 4, ...).
                out.append(self.Aclassifier(pooled[::2]))
            else:
                out.append(self.Aclassifier(self._pool(self.model(last_cont))))
        return out
# if __name__ == '__main__':
# model = MainModel()
# summary(model, (3, 112, 112))