Commit

update pyramidmamba

WangLibo1995 committed Jul 22, 2024
1 parent 0e7ed6f commit 2e9eaca
Showing 3 changed files with 254 additions and 2 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -8,6 +8,7 @@
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/efficient-hybrid-transformer-learning-global/semantic-segmentation-on-uavid)](https://paperswithcode.com/sota/semantic-segmentation-on-uavid?p=efficient-hybrid-transformer-learning-global)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/efficient-hybrid-transformer-learning-global/semantic-segmentation-on-loveda)](https://paperswithcode.com/sota/semantic-segmentation-on-loveda?p=efficient-hybrid-transformer-learning-global)

- The code of [PyramidMamba](./geoseg/models/PyramidMamba.py) is released.
- I have updated this repo to PyTorch 2.0 and PyTorch Lightning 2.0, with support for multi-GPU training, among other changes.
- Pretrained weights of backbones can be accessed from [Google Drive](https://drive.google.com/drive/folders/1ELpFKONJZbXmwB5WCXG7w42eHtrXzyPn?usp=sharing).
- [UNetFormer](https://www.sciencedirect.com/science/article/pii/S0924271622001654) (accepted by ISPRS, [PDF](https://www.researchgate.net/profile/Libo-Wang-17/publication/361736439_UNetFormer_A_UNet-like_transformer_for_efficient_semantic_segmentation_of_remote_sensing_urban_scene_imagery/links/62c2a1ed1cbf3a1d12ac1c87/UNetFormer-A-UNet-like-transformer-for-efficient-semantic-segmentation-of-remote-sensing-urban-scene-imagery.pdf)) and **UAVid dataset** are supported.
@@ -37,13 +38,18 @@ which mainly focuses on developing advanced Vision Transformers for remote sensing
- [ISPRS Vaihingen and Potsdam](https://www.isprs.org/education/benchmarks/UrbanSemLab/default.aspx)
- [UAVid](https://uavid.nl/)
- [LoveDA](https://codalab.lisn.upsaclay.fr/competitions/421)
- [OpenEarthMap](https://open-earth-map.org/)
- More datasets will be supported in the future.

- Multi-scale Training and Testing
- Inference on Huge Remote Sensing Images

## Supported Networks

- Mamba

- [PyramidMamba](https://arxiv.org/abs/2406.10828)

- Vision Transformer

- [UNetFormer](https://authors.elsevier.com/a/1fIji3I9x1j9Fs)
@@ -111,6 +117,12 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org
pip install -r GeoSeg/requirements.txt
```

Install Mamba (quote the version specifier so the shell does not interpret `>=` as redirection):
```
pip install "causal-conv1d>=1.4.0"
pip install mamba-ssm
```
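
To verify the install, here is a minimal smoke test (a sketch, not from the original README); note that `mamba_ssm` requires a CUDA device:
```
import torch
from mamba_ssm import Mamba

# Mamba consumes (batch, seq_len, d_model) token sequences.
m = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).cuda()
x = torch.randn(2, 256, 64).cuda()
print(m(x).shape)  # torch.Size([2, 256, 64])
```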

## Pretrained Weights of Backbones

[Baidu Disk](https://pan.baidu.com/s/1foJkxeUZwVi5SnKNpn6hfg) (extraction code: 1234)
239 changes: 239 additions & 0 deletions geoseg/models/PyramidMamba.py
@@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
import timm
from mamba_ssm import Mamba


class ConvBNReLU(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
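        # This padding keeps the spatial size when stride=1 ("same" padding for odd
        # kernel sizes) and halves it in the usual way when stride=2.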
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
norm_layer(out_channels),
nn.ReLU6()
)


class ConvBN(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
super(ConvBN, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
norm_layer(out_channels)
)


class Conv(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
super(Conv, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
)


class SeparableConvBNReLU(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
norm_layer=nn.BatchNorm2d):
super(SeparableConvBNReLU, self).__init__(
nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
groups=in_channels, bias=False),
norm_layer(in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.ReLU6()
)


class SeparableConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            # The depthwise conv outputs in_channels channels, so normalize those;
            # norm_layer(out_channels) would fail whenever in_channels != out_channels.
            norm_layer(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class SeparableConv(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
super(SeparableConv, self).__init__(
nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
groups=in_channels, bias=False),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
)


class MambaLayer(nn.Module):
def __init__(self, in_chs=512, dim=128, d_state=16, d_conv=4, expand=2, last_feat_size=16):
super().__init__()
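        # Pooling scales form an arithmetic sequence 1, 1+step, ... below last_feat_size;
        # with the default last_feat_size=16 (step = 2) this gives [1, 3, 5, 7, 9, 11, 13, 15].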
pool_scales = self.generate_arithmetic_sequence(1, last_feat_size, last_feat_size // 8)
self.pool_len = len(pool_scales)
self.pool_layers = nn.ModuleList()
self.pool_layers.append(nn.Sequential(
ConvBNReLU(in_chs, dim, kernel_size=1),
nn.AdaptiveAvgPool2d(1)
))
for pool_scale in pool_scales[1:]:
self.pool_layers.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvBNReLU(in_chs, dim, kernel_size=1)
))
self.mamba = Mamba(
d_model=dim*self.pool_len+in_chs, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand # Block expansion factor
)

def forward(self, x): # B, C, H, W
res = x
B, C, H, W = res.shape
ppm_out = [res]
for p in self.pool_layers:
pool_out = p(x)
pool_out = F.interpolate(pool_out, (H, W), mode='bilinear', align_corners=False)
ppm_out.append(pool_out)
x = torch.cat(ppm_out, dim=1)
_, chs, _, _ = x.shape
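        # Flatten the fused feature map into a (B, H*W, C) token sequence for Mamba,
        # then fold it back into a (B, C, H, W) map below.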
x = rearrange(x, 'b c h w -> b (h w) c', b=B, c=chs, h=H, w=W)
x = self.mamba(x)
x = x.transpose(2, 1).view(B, chs, H, W)
return x

    def generate_arithmetic_sequence(self, start, stop, step):
        return list(range(start, stop, step))


class ConvFFN(nn.Module):
def __init__(self, in_ch=128, hidden_ch=512, out_ch=128, drop=0.):
super(ConvFFN, self).__init__()
self.conv = ConvBNReLU(in_ch, in_ch, kernel_size=3)
self.fc1 = Conv(in_ch, hidden_ch, kernel_size=1)
self.act = nn.GELU()
self.fc2 = Conv(hidden_ch, out_ch, kernel_size=1)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.conv(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)

return x


class Block(nn.Module):
def __init__(self, in_chs=512, dim=128, hidden_ch=512, out_ch=128, drop=0.1, d_state=16, d_conv=4, expand=2, last_feat_size=16):
super(Block, self).__init__()
self.mamba = MambaLayer(in_chs=in_chs, dim=dim, d_state=d_state, d_conv=d_conv, expand=expand, last_feat_size=last_feat_size)
self.conv_ffn = ConvFFN(in_ch=dim*self.mamba.pool_len+in_chs, hidden_ch=hidden_ch, out_ch=out_ch, drop=drop)

def forward(self, x):
x = self.mamba(x)
x = self.conv_ffn(x)

return x


class Decoder(nn.Module):
def __init__(self, encoder_channels=(64, 128, 256, 512), decoder_channels=128, num_classes=6, last_feat_size=16):
super().__init__()
self.b3 = Block(in_chs=encoder_channels[-1], dim=decoder_channels, last_feat_size=last_feat_size)
self.up_conv = nn.Sequential(ConvBNReLU(decoder_channels, decoder_channels),
nn.Upsample(scale_factor=2),
ConvBNReLU(decoder_channels, decoder_channels),
nn.Upsample(scale_factor=2),
ConvBNReLU(decoder_channels, decoder_channels),
nn.Upsample(scale_factor=2),
)
self.pre_conv = ConvBNReLU(encoder_channels[0], decoder_channels)
self.head = nn.Sequential(ConvBNReLU(decoder_channels, decoder_channels // 2),
nn.Upsample(scale_factor=2, mode='bilinear'),
ConvBNReLU(decoder_channels // 2, decoder_channels // 2),
nn.Upsample(scale_factor=2, mode='bilinear'),
Conv(decoder_channels // 2, num_classes, kernel_size=1))
self.apply(self._init_weights)

def forward(self, x0, x3):
x3 = self.b3(x3)
x3 = self.up_conv(x3)
x = x3 + self.pre_conv(x0)
x = self.head(x)
return x

def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Conv2d) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)


class EfficientPyramidMamba(nn.Module):
def __init__(self,
backbone_name='swsl_resnet18',
pretrained=True,
num_classes=6,
decoder_channels=128,
last_feat_size=16 # last_feat_size=input_img_size // 32
):
super().__init__()

self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32,
out_indices=(1, 4), pretrained=pretrained)
encoder_channels = self.backbone.feature_info.channels()
self.decoder = Decoder(encoder_channels=encoder_channels, decoder_channels=decoder_channels, num_classes=num_classes, last_feat_size=last_feat_size)

def forward(self, x):
x0, x3 = self.backbone(x)
x = self.decoder(x0, x3)

return x


class PyramidMamba(nn.Module):
def __init__(self,
backbone_name='swin_base_patch4_window12_384.ms_in22k_ft_in1k',
pretrained=True,
num_classes=6,
decoder_channels=128,
last_feat_size=32,
img_size=1024
):
super().__init__()

self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32, img_size=img_size,
out_indices=(-4, -1), pretrained=pretrained)

encoder_channels = self.backbone.feature_info.channels()
self.decoder = Decoder(encoder_channels=encoder_channels, decoder_channels=decoder_channels, num_classes=num_classes, last_feat_size=last_feat_size)

def forward(self, x):
x0, x3 = self.backbone(x)
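        # timm Swin backbones return NHWC feature maps; permute to NCHW for the decoder.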
x0 = x0.permute(0, 3, 1, 2)
x3 = x3.permute(0, 3, 1, 2)
x = self.decoder(x0, x3)

return x
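

# --- Minimal usage sketch (illustration only, not part of the original commit). ---
# Assumes a CUDA device (mamba_ssm requires one) and a 512x512 input, so the deepest
# feature map is 512 // 32 = 16 pixels wide, matching last_feat_size=16.
if __name__ == "__main__":
    model = EfficientPyramidMamba(pretrained=False, num_classes=6, last_feat_size=16).cuda()
    dummy = torch.randn(2, 3, 512, 512).cuda()
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 6, 512, 512])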

5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,6 @@
-timm
-lightning
+timm==0.9.16
+lightning==2.0.0
+pytorch-lightning==2.3.0
albumentations
ttach
numpy
