-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDetailsNet.py
153 lines (116 loc) · 5.34 KB
/
DetailsNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# %% import library
import torch
import torch.nn as nn
import torch.nn.functional as F
# %% submodules
class CL(nn.Module):
    """Convolution + LeakyReLU block (the 'CL' unit of the encoder)."""

    def __init__(self, input_channel, output_channel, kernel_size=4, stride=2, padding=1):
        """
        A 4x4 convolution with stride=2, padding=1, followed by a
        leaky rectified linear unit (LeakyReLU, negative slope 0.2).

        :param input_channel: input channel size
        :param output_channel: output channel size
        :param kernel_size: kernel size of module
        :param stride: stride of module
        :param padding: padding of module
        """
        assert input_channel > 0 and output_channel > 0
        super(CL, self).__init__()
        conv = nn.Conv2d(input_channel, output_channel,
                         kernel_size=kernel_size, stride=stride, padding=padding)
        act = nn.LeakyReLU(0.2, inplace=False)
        self.layers = nn.Sequential(conv, act)

    def forward(self, x):
        """Run the conv + LeakyReLU pipeline over *x*."""
        return self.layers(x)
class CBL(nn.Module):
    """Convolution + BatchNorm + LeakyReLU block (the 'CBL' unit)."""

    def __init__(self, input_channel, output_channel, kernel_size=4, stride=2, padding=1):
        """
        A 4x4 convolution with stride=2, padding=1, then batch
        normalization, followed by a leaky rectified linear unit
        (LeakyReLU, negative slope 0.2).

        :param input_channel: input channel size
        :param output_channel: output channel size
        :param kernel_size: kernel size of module
        :param stride: stride of module
        :param padding: padding of module
        """
        assert input_channel > 0 and output_channel > 0
        super(CBL, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_channel, output_channel,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(num_features=output_channel),
            nn.LeakyReLU(0.2, inplace=False),
        )

    def forward(self, x):
        """Run the conv + batch-norm + LeakyReLU pipeline over *x*."""
        return self.layers(x)
class C(nn.Module):
    """Final 3x3 convolution mapping features to the output classes."""

    def __init__(self, input_channel, output_channel, kernel_size=3, stride=1, padding=1):
        """
        At the final layer, a 3x3 convolution maps the feature vector
        to the desired number of classes, squashed by Tanh to [-1, 1].

        NOTE(review): the original comment here claimed Sigmoid should be
        used because inputs were never normalized from [0, 1] to [-1, 1],
        yet the code applies Tanh — confirm which range is intended.

        :param input_channel: input channel size
        :param output_channel: output channel size
        :param kernel_size: kernel size of module
        :param stride: stride of module
        :param padding: padding of module
        """
        super(C, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(input_channel, output_channel,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map *x* to output channels and squash with Tanh."""
        return self.layer(x)
# %% residual block
class ResidualBlock(nn.Module):
    """A CBL unit followed by a CL unit with shared conv hyperparameters.

    NOTE: the block itself contains no skip connection; the additive
    skips are applied by the caller (see DetailsNet.forward).
    """

    def __init__(self, input_channel, output_channel, kernel_size=3, stride=1, padding=1):
        """
        A residual block: a CBL stage into a CL stage, both built with
        the same kernel size, stride and padding.

        :param input_channel: number of input channels of input images to network.
        :param output_channel: number of output channels of output images of network.
        :param stride: stride size of CBL and CL modules
        :param kernel_size: kernel_size of CBL and CL modules
        :param padding: padding size of CBL and CL modules
        """
        super(ResidualBlock, self).__init__()
        self.input_channel = input_channel
        self.output_channel = output_channel
        # first stage normalizes, second stage does not
        self.cbl = CBL(input_channel, output_channel, kernel_size, stride, padding)
        self.cl = CL(output_channel, output_channel, kernel_size, stride, padding)

    def forward(self, x):
        """Pass *x* through the CBL stage, then the CL stage."""
        return self.cl(self.cbl(x))
# %% details net
class DetailsNet(nn.Module):
    """GAN generator: stacked residual blocks with additive skips, then a C head."""

    def __init__(self, input_channels=32, output_channels=3):
        """
        The generator of GAN networks contains repeated residual blocks and C block at the end.

        :param input_channels: number of input channels of input images to network. Actually, it is latent vector length
               which is fusion of I_h, I_a, I_c, I_e vectors which is called I_f.
        :param output_channels: number of output channels of output images of network.
        """
        super(DetailsNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.block0 = ResidualBlock(input_channel=self.input_channels, output_channel=64)
        self.block1 = ResidualBlock(input_channel=64, output_channel=64)
        self.block2 = ResidualBlock(input_channel=64, output_channel=64)
        self.block3 = ResidualBlock(input_channel=64, output_channel=64)
        self.final = C(input_channel=64, output_channel=self.output_channels)

    def forward(self, x):
        """Apply block0, then block1..block3 each with an additive skip, then the C head.

        :param x: input tensor with ``input_channels`` channels (NCHW assumed — TODO confirm)
        :return: tensor with ``output_channels`` channels, squashed by the C head's Tanh
        """
        x = self.block0(x)
        residual1 = x
        x = self.block1(x)
        x = x + residual1  # out-of-place add: avoids in-place mutation through the skip
        residual2 = x
        # BUG FIX: the original called self.block3 here, applying it twice
        # and leaving self.block2 unused (dead parameters).
        x = self.block2(x)
        x = x + residual2
        residual3 = x
        x = self.block3(x)
        x = x + residual3
        x = self.final(x)
        return x
# %% tests
# z = torch.randn(size=(1, 32, 256, 256))
# details_net = DetailsNet()
# zo = details_net(z)