Add plainvit_tiny model
qinliuliuqin committed Jan 25, 2023
1 parent c1d9823 commit c8343aa
Showing 3 changed files with 308 additions and 0 deletions.
7 changes: 7 additions & 0 deletions isegm/model/modeling/models_vit.py
@@ -298,6 +298,13 @@ def forward(self, x):
x = self.head(x)
return x


def vit_tiny_patch16(**kwargs):
model = VisionTransformer(
patch_size=(16, 16), embed_dim=160, depth=8, num_heads=4, mlp_ratio=4, qkv_bias=True, **kwargs)
return model


def vit_base_patch16(**kwargs):
model = VisionTransformer(
patch_size=(16, 16), embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
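
For orientation, a minimal usage sketch of the new factory function (not part of this commit). The import path follows the file above; img_size and in_chans are assumed to be accepted by VisionTransformer, mirroring the backbone_params in the plainvit_tiny448 config added below.

# Hypothetical sanity check; img_size/in_chans are assumed kwargs.
from isegm.model.modeling.models_vit import vit_tiny_patch16

backbone = vit_tiny_patch16(img_size=(448, 448), in_chans=3)
n_params = sum(p.numel() for p in backbone.parameters())
print(f'{n_params / 1e6:.1f}M backbone parameters')  # tiny variant: embed_dim=160, depth=8, num_heads=4
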
128 changes: 128 additions & 0 deletions models/iter_mask/plainvit_tiny448_sbd_itermask.py
@@ -0,0 +1,128 @@
from isegm.utils.exp_imports.default import *
from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss

MODEL_NAME = 'sbd_plainvit_tiny448'


def main(cfg):
model, model_cfg = init_model(cfg)
train(model, cfg, model_cfg)


def init_model(cfg):
model_cfg = edict()
model_cfg.crop_size = (448, 448)
model_cfg.num_max_points = 24

backbone_params = dict(
img_size=model_cfg.crop_size,
patch_size=(16,16),
in_chans=3,
embed_dim=160,
depth=8,
num_heads=4,
mlp_ratio=4,
qkv_bias=True,
)

neck_params = dict(
in_dim = 160,
out_dims = [96, 192, 288, 384],
)

head_params = dict(
in_channels=[96, 192, 288, 384],
in_index=[0, 1, 2, 3],
dropout_ratio=0.1,
num_classes=1,
loss_decode=CrossEntropyLoss(),
align_corners=False,
upsample=cfg.upsample,
channels=128
)

model = PlainVitModel(
use_disks=True,
norm_radius=5,
with_prev_mask=True,
backbone_params=backbone_params,
neck_params=neck_params,
head_params=head_params,
random_split=cfg.random_split,
)

# model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_BASE)
model.to(cfg.device)

return model, model_cfg


def train(model, cfg, model_cfg):
cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size
cfg.val_batch_size = cfg.batch_size
crop_size = model_cfg.crop_size

loss_cfg = edict()
loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
loss_cfg.instance_loss_weight = 1.0

train_augmentator = Compose([
UniformRandomResize(scale_range=(0.75, 1.25)),
Flip(),
RandomRotate90(),
ShiftScaleRotate(shift_limit=0.03, scale_limit=0,
rotate_limit=(-3, 3), border_mode=0, p=0.75),
PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
RandomCrop(*crop_size),
RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75),
RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75)
], p=1.0)

val_augmentator = Compose([
UniformRandomResize(scale_range=(0.75, 1.25)),
PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
RandomCrop(*crop_size)
], p=1.0)

points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80,
merge_objects_prob=0.15,
max_num_merged_objects=2)

trainset = SBDDataset(
cfg.SBD_PATH,
split='train',
augmentator=train_augmentator,
min_object_area=80,
keep_background_prob=0.01,
points_sampler=points_sampler,
samples_scores_path='./assets/sbd_samples_weights.pkl',
samples_scores_gamma=1.25
)

valset = SBDDataset(
cfg.SBD_PATH,
split='val',
augmentator=val_augmentator,
min_object_area=80,
points_sampler=points_sampler,
epoch_len=500
)

optimizer_params = {
'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8
}

lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
milestones=[50, 55], gamma=0.1)
trainer = ISTrainer(model, cfg, model_cfg, loss_cfg,
trainset, valset,
optimizer='adam',
optimizer_params=optimizer_params,
layerwise_decay=cfg.layerwise_decay,
lr_scheduler=lr_scheduler,
checkpoint_interval=[(0, 20), (50, 1)],
image_dump_interval=300,
metrics=[AdaptiveIoU()],
max_interactive_points=model_cfg.num_max_points,
max_num_next_clicks=3)
trainer.run(num_epochs=55, validation=False)
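
A note on the schedule above, assuming the trainer steps the scheduler once per epoch (the epoch-valued milestones suggest this): with MultiStepLR(milestones=[50, 55], gamma=0.1) and a base learning rate of 5e-5, training runs at 5e-5 for epochs 0-49 and 5e-6 for epochs 50-54; the second milestone never takes effect because trainer.run stops after 55 epochs. A self-contained sketch of that decay, using a dummy parameter only so an optimizer can be built:

import torch

dummy = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([dummy], lr=5e-5, betas=(0.9, 0.999), eps=1e-8)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 55], gamma=0.1)
for epoch in range(55):
    lr = scheduler.get_last_lr()[0]  # learning rate used during this epoch
    if epoch in (0, 49, 50, 54):
        print(epoch, lr)             # 5e-05 up to epoch 49, 5e-06 from epoch 50 on
    optimizer.step()
    scheduler.step()
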
173 changes: 173 additions & 0 deletions notebooks/test_simpleclick_tiny_model_size.ipynb
@@ -0,0 +1,173 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import sys\n",
"\n",
"sys.path.append(str(Path(sys.path[0]).parent))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from isegm.model.is_plainvit_model import PlainVitModel\n",
"from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss\n",
"\n",
"\n",
"embed_dim = 160\n",
"neck_dims = [96, 192, 288, 384]\n",
"\n",
"def params_vit_tiny_448(**kwargs):\n",
"\n",
" backbone_params = dict(img_size=(448, 448), patch_size=(16,16), in_chans=3,\n",
" embed_dim=embed_dim, depth=8, num_heads=4, mlp_ratio=4, qkv_bias=True,)\n",
"\n",
" neck_params = dict(in_dim=embed_dim, out_dims=neck_dims,)\n",
"\n",
" head_params = dict(in_channels=neck_dims, in_index=[0, 1, 2, 3],\n",
" dropout_ratio=0.1, num_classes=1, loss_decode=CrossEntropyLoss(),\n",
" align_corners=False, channels=128,)\n",
"\n",
" return backbone_params, neck_params, head_params\n",
"\n",
"\n",
"vith_backbone_params, vith_neck_params, vith_head_params = params_vit_tiny_448()\n",
"model_vitt = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, \n",
" backbone_params=vith_backbone_params, neck_params=vith_neck_params, \n",
" head_params=vith_head_params)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----\n",
"13.6 M\n",
"54.3 M\n",
"38.7 M\n",
"11.4 M\n",
"3.6 M\n"
]
}
],
"source": [
"def get_params_count(model):\n",
" return sum(param.numel() for param in model.parameters())\n",
"\n",
"for model in [model_vitt]:\n",
" print('----')\n",
" print('{:.1f} M'.format(get_params_count(model) / 1024 / 1024))\n",
" print('{:.1f} M'.format(get_params_count(model) * 4.0 / 1024 / 1024))\n",
" print('{:.1f} M'.format(get_params_count(model.backbone) * 4.0 / 1024 / 1024))\n",
" print('{:.1f} M'.format(get_params_count(model.neck) * 4.0 / 1024 / 1024))\n",
" print('{:.1f} M'.format(get_params_count(model.head) * 4.0 / 1024 / 1024))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.\n",
"[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.\n",
"[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
"10.52171G 3.71619M\n",
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.\n",
"[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.\n",
"[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
"10.52171G 3.71619M\n"
]
}
],
"source": [
"import torch\n",
"from thop import profile\n",
"from thop import clever_format\n",
"\n",
"\n",
"# ViT-B, ViT-L, ViT-H 448\n",
"input = torch.randn(1, 4, 448, 448)\n",
"point = torch.randn(1, 2, 3)\n",
"\n",
"for model in [model_vitt]:\n",
" model.eval()\n",
"\n",
" macs, params = profile(model, inputs=(input, point))\n",
" gflops, params = clever_format([macs*2, params], \"%.5f\")\n",
"\n",
" print(gflops, params)\n",
"\n",
"\n",
"vitb_backbone_params, vitb_neck_params, vitb_head_params = params_vit_tiny_448()\n",
"model_vitb = PlainVitModel(use_disks=True, norm_radius=5, with_prev_mask=True, \n",
" backbone_params=vitb_backbone_params, neck_params=vitb_neck_params, \n",
" head_params=vitb_head_params)\n",
"\n",
"# ViT-B-224\n",
"input = torch.randn(1, 4, 448, 448)\n",
"point = torch.randn(1, 2, 3)\n",
"\n",
"for model in [model_vitb]:\n",
" model.eval()\n",
"\n",
" macs, params = profile(model, inputs=(input, point))\n",
" gflops, params = clever_format([macs*2, params], \"%.5f\")\n",
"\n",
" print(gflops, params)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "e89c9d8a920cc6bbfac98998034c8a15ed2e75dae1678af787f4c4e38e5be518"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
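
One note on the notebook's reporting: thop's profile returns multiply-accumulate counts (MACs), and the cells above multiply by two to express them as FLOPs (one multiply plus one add per MAC). A tiny self-contained illustration of that convention, assuming thop is installed:

import torch
from thop import profile

# A Linear(160, 160) layer performs 160 * 160 = 25,600 MACs per sample,
# i.e. 51,200 FLOPs under the one-MAC-equals-two-FLOPs convention.
layer = torch.nn.Linear(160, 160)
macs, params = profile(layer, inputs=(torch.randn(1, 160),))
print(macs, macs * 2)
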
