Commit 2e9eaca, 1 parent (0e7ed6f). Showing 3 changed files with 254 additions and 2 deletions.
@@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import timm
from mamba_ssm import Mamba


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU6()
        )


class ConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBN, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels)
        )


class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class SeparableConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU6()
        )


class SeparableConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            # The norm follows the depthwise conv, whose output width is in_channels,
            # so it must be norm_layer(in_channels); norm_layer(out_channels) would
            # fail at runtime whenever in_channels != out_channels.
            norm_layer(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class SeparableConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
        super(SeparableConv, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class MambaLayer(nn.Module):
    def __init__(self, in_chs=512, dim=128, d_state=16, d_conv=4, expand=2, last_feat_size=16):
        super().__init__()
        pool_scales = self.generate_arithmetic_sequence(1, last_feat_size, last_feat_size // 8)
        self.pool_len = len(pool_scales)
        self.pool_layers = nn.ModuleList()
        self.pool_layers.append(nn.Sequential(
            ConvBNReLU(in_chs, dim, kernel_size=1),
            nn.AdaptiveAvgPool2d(1)
        ))
        for pool_scale in pool_scales[1:]:
            self.pool_layers.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvBNReLU(in_chs, dim, kernel_size=1)
                ))
        self.mamba = Mamba(
            d_model=dim * self.pool_len + in_chs,  # model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,  # local convolution width
            expand=expand  # block expansion factor
        )

    def forward(self, x):  # B, C, H, W
        res = x
        B, C, H, W = res.shape
        ppm_out = [res]
        for p in self.pool_layers:
            pool_out = p(x)
            pool_out = F.interpolate(pool_out, (H, W), mode='bilinear', align_corners=False)
            ppm_out.append(pool_out)
        x = torch.cat(ppm_out, dim=1)
        _, chs, _, _ = x.shape
        # flatten the spatial grid into a sequence so Mamba can scan it
        x = rearrange(x, 'b c h w -> b (h w) c', b=B, c=chs, h=H, w=W)
        x = self.mamba(x)
        x = x.transpose(2, 1).view(B, chs, H, W)
        return x

    def generate_arithmetic_sequence(self, start, stop, step):
        sequence = []
        for i in range(start, stop, step):
            sequence.append(i)
        return sequence


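# A worked example of the pyramid arithmetic above (illustrative note, assuming
# the default arguments in_chs=512, dim=128, last_feat_size=16):
# generate_arithmetic_sequence(1, 16, 16 // 8) returns [1, 3, 5, 7, 9, 11, 13, 15],
# so pool_len = 8 and Mamba is built with
# d_model = dim * pool_len + in_chs = 128 * 8 + 512 = 1536.

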
class ConvFFN(nn.Module):
    def __init__(self, in_ch=128, hidden_ch=512, out_ch=128, drop=0.):
        super(ConvFFN, self).__init__()
        self.conv = ConvBNReLU(in_ch, in_ch, kernel_size=3)
        self.fc1 = Conv(in_ch, hidden_ch, kernel_size=1)
        self.act = nn.GELU()
        self.fc2 = Conv(hidden_ch, out_ch, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self, in_chs=512, dim=128, hidden_ch=512, out_ch=128, drop=0.1, d_state=16, d_conv=4, expand=2, last_feat_size=16):
        super(Block, self).__init__()
        self.mamba = MambaLayer(in_chs=in_chs, dim=dim, d_state=d_state, d_conv=d_conv, expand=expand, last_feat_size=last_feat_size)
        self.conv_ffn = ConvFFN(in_ch=dim * self.mamba.pool_len + in_chs, hidden_ch=hidden_ch, out_ch=out_ch, drop=drop)

    def forward(self, x):
        x = self.mamba(x)
        x = self.conv_ffn(x)
        return x


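# Shape note (illustrative): MambaLayer widens its input to dim * pool_len + in_chs
# channels, and ConvFFN projects that back down to out_ch (128 by default), so a
# Block maps (B, in_chs, H, W) -> (B, out_ch, H, W) at unchanged spatial resolution.

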
class Decoder(nn.Module):
    def __init__(self, encoder_channels=(64, 128, 256, 512), decoder_channels=128, num_classes=6, last_feat_size=16):
        super().__init__()
        self.b3 = Block(in_chs=encoder_channels[-1], dim=decoder_channels, last_feat_size=last_feat_size)
        self.up_conv = nn.Sequential(ConvBNReLU(decoder_channels, decoder_channels),
                                     nn.Upsample(scale_factor=2),
                                     ConvBNReLU(decoder_channels, decoder_channels),
                                     nn.Upsample(scale_factor=2),
                                     ConvBNReLU(decoder_channels, decoder_channels),
                                     nn.Upsample(scale_factor=2),
                                     )
        self.pre_conv = ConvBNReLU(encoder_channels[0], decoder_channels)
        self.head = nn.Sequential(ConvBNReLU(decoder_channels, decoder_channels // 2),
                                  nn.Upsample(scale_factor=2, mode='bilinear'),
                                  ConvBNReLU(decoder_channels // 2, decoder_channels // 2),
                                  nn.Upsample(scale_factor=2, mode='bilinear'),
                                  Conv(decoder_channels // 2, num_classes, kernel_size=1))
        self.apply(self._init_weights)

    def forward(self, x0, x3):
        # x0: 1/4-scale low-level feature; x3: 1/32-scale high-level feature
        x3 = self.b3(x3)
        x3 = self.up_conv(x3)  # three 2x upsamples: 1/32 -> 1/4 scale
        x = x3 + self.pre_conv(x0)
        x = self.head(x)  # two more 2x upsamples back to input resolution
        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Conv2d) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class EfficientPyramidMamba(nn.Module):
    def __init__(self,
                 backbone_name='swsl_resnet18',
                 pretrained=True,
                 num_classes=6,
                 decoder_channels=128,
                 last_feat_size=16  # last_feat_size = input_img_size // 32
                 ):
        super().__init__()
        # out_indices=(1, 4) selects the stride-4 and stride-32 feature maps
        self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32,
                                          out_indices=(1, 4), pretrained=pretrained)
        encoder_channels = self.backbone.feature_info.channels()
        self.decoder = Decoder(encoder_channels=encoder_channels, decoder_channels=decoder_channels, num_classes=num_classes, last_feat_size=last_feat_size)

    def forward(self, x):
        x0, x3 = self.backbone(x)
        x = self.decoder(x0, x3)
        return x


class PyramidMamba(nn.Module):
    def __init__(self,
                 backbone_name='swin_base_patch4_window12_384.ms_in22k_ft_in1k',
                 pretrained=True,
                 num_classes=6,
                 decoder_channels=128,
                 last_feat_size=32,
                 img_size=1024
                 ):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32, img_size=img_size,
                                          out_indices=(-4, -1), pretrained=pretrained)
        encoder_channels = self.backbone.feature_info.channels()
        self.decoder = Decoder(encoder_channels=encoder_channels, decoder_channels=decoder_channels, num_classes=num_classes, last_feat_size=last_feat_size)

    def forward(self, x):
        x0, x3 = self.backbone(x)
        # timm's Swin feature maps come out NHWC; convert to NCHW for the decoder
        x0 = x0.permute(0, 3, 1, 2)
        x3 = x3.permute(0, 3, 1, 2)
        x = self.decoder(x0, x3)
        return x

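
# A minimal smoke-test sketch (illustrative, not part of the committed file).
# It assumes a CUDA device, since the mamba_ssm kernels run on GPU, and uses
# pretrained=False to skip the backbone weight download.
if __name__ == "__main__":
    model = EfficientPyramidMamba(pretrained=False, num_classes=6).cuda().eval()
    dummy = torch.randn(1, 3, 512, 512, device="cuda")  # 512 // 32 = 16 = last_feat_size
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 6, 512, 512])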
@@ -1,5 +1,6 @@
-timm
-lightning
+timm==0.9.16
+lightning==2.0.0
+pytorch-lightning==2.3.0
 albumentations
 ttach
 numpy