-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
Copy pathic_neck.py
147 lines (131 loc) · 5.19 KB
/
ic_neck.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import NECKS
class CascadeFeatureFusion(BaseModule):
    """Cascade Feature Fusion (CFF) unit used by ICNet.

    The low resolution feature map is bilinearly resized to the spatial size
    of the high resolution one, passed through a 3x3 convolution with
    dilation 2, and summed with the 1x1-projected high resolution feature
    map. A ReLU is applied to the sum.

    Args:
        low_channels (int): Number of channels of the low resolution
            input feature map.
        high_channels (int): Number of channels of the high resolution
            input feature map.
        out_channels (int): Number of channels of both outputs.
        conv_cfg (dict): Config used to build the conv layers.
            Default: None.
        norm_cfg (dict): Config used to build the norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config used to build the activation layers.
            Default: dict(type='ReLU').
        align_corners (bool): ``align_corners`` flag forwarded to the
            bilinear resize. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Default: None.

    Returns:
        tuple[Tensor, Tensor]: ``(x, x_low)`` where ``x`` is the fused
        feature map and ``x_low`` is the resized and convolved low
        resolution branch, kept for Cascade Label Guidance in auxiliary
        heads. Both have ``out_channels`` channels and the spatial size
        of ``x_high``.
    """

    def __init__(self,
                 low_channels,
                 high_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners

        # Layer configs shared by both projection branches.
        layer_cfgs = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Low resolution branch: dilated 3x3 conv (padding matches dilation).
        self.conv_low = ConvModule(
            low_channels,
            out_channels,
            kernel_size=3,
            padding=2,
            dilation=2,
            **layer_cfgs)
        # High resolution branch: 1x1 projection to `out_channels`.
        self.conv_high = ConvModule(
            high_channels, out_channels, kernel_size=1, **layer_cfgs)

    def forward(self, x_low, x_high):
        """Fuse ``x_low`` into ``x_high``; return ``(fused, low_branch)``."""
        target_size = x_high.shape[2:]
        upsampled = resize(
            x_low,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # Note: unlike the original paper, the low resolution branch goes
        # through `self.conv_low` instead of a separate 1x1 conv classifier
        # before it is handed to the auxiliary head.
        low_feat = self.conv_low(upsampled)
        high_feat = self.conv_high(x_high)

        fused = F.relu(low_feat + high_feat, inplace=True)
        return fused, low_feat
@NECKS.register_module()
class ICNeck(BaseModule):
    """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

    This neck is the feature fusion part of `ICNet
    <https://arxiv.org/abs/1704.08545>`_. It chains two
    :class:`CascadeFeatureFusion` units: the first fuses the third input
    into the second, the second fuses that result into the first input.

    Args:
        in_channels (Sequence[int]): Number of channels of the three input
            feature maps, in the same order as the ``inputs`` passed to
            :meth:`forward`. Must have length 3.
            Default: (64, 256, 256).
        out_channels (int): Number of channels of every output feature map.
            Default: 128.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Raises:
        AssertionError: If ``in_channels`` (at construction) or ``inputs``
            (in :meth:`forward`) does not contain exactly three items.
    """

    def __init__(self,
                 in_channels=(64, 256, 256),
                 out_channels=128,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super(ICNeck, self).__init__(init_cfg=init_cfg)
        # Single-line literal: a backslash continuation inside the string
        # would splice the next line's indentation into the message.
        assert len(in_channels) == 3, 'Length of input channels must be 3!'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # Fuses the third input (low res side) into the second input.
        self.cff_24 = CascadeFeatureFusion(
            self.in_channels[2],
            self.in_channels[1],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

        # Fuses the output of `cff_24` (already `out_channels` wide) into
        # the first input.
        self.cff_12 = CascadeFeatureFusion(
            self.out_channels,
            self.in_channels[0],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, inputs):
        """Run the two cascade fusion stages.

        Args:
            inputs (Sequence[Tensor]): Exactly three feature maps
                ``(x_sub1, x_sub2, x_sub4)`` whose channel counts match
                ``in_channels``.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(x_24, x_12, x_cff_12)``.
            ``x_24`` and ``x_12`` are the low resolution branch outputs of
            the two fusion units (for auxiliary heads); ``x_cff_12`` is the
            final fused feature map (for the decode head).
        """
        assert len(inputs) == 3, 'Length of input feature maps must be 3!'
        x_sub1, x_sub2, x_sub4 = inputs

        x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
        x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)

        # Note: `x_cff_12` is used for decode_head,
        # `x_24` and `x_12` are used for auxiliary head.
        return x_24, x_12, x_cff_12