Add xtiny model for c+l (COCO+LVIS)
qinliuliuqin committed Mar 22, 2023
1 parent 1679311 commit a066503
Showing 1 changed file with 125 additions and 0 deletions.
models/iter_mask/plainvit_xtiny448_cocolvis_itermask.py

from isegm.utils.exp_imports.default import *
from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss

MODEL_NAME = 'cocolvis_plainvit_xtiny448'


def main(cfg):
    model, model_cfg = init_model(cfg)
    train(model, cfg, model_cfg)


def init_model(cfg):
    model_cfg = edict()
    model_cfg.crop_size = (448, 448)
    model_cfg.num_max_points = 24

    backbone_params = dict(
        img_size=model_cfg.crop_size,
        patch_size=(16, 16),
        in_chans=3,
        embed_dim=160,
        depth=8,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
    )
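    # Rough scale check (a back-of-envelope estimate that ignores the patch
    # embedding, positional embeddings, and norm layers): each transformer
    # block holds about 12 * embed_dim**2 weights (4*d*d for the attention
    # qkv/proj plus 2*mlp_ratio*d*d for the MLP), so 8 * 12 * 160**2 gives
    # roughly 2.5M trunk parameters, which is what earns this variant the
    # "xtiny" name.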

    # Project the single-scale, 160-channel ViT features into a four-level
    # feature pyramid for the segmentation head (cf. ViTDet's simple
    # feature pyramid).
    neck_params = dict(
        in_dim=160,
        out_dims=[96, 192, 288, 384],
    )

    head_params = dict(
        in_channels=[96, 192, 288, 384],
        in_index=[0, 1, 2, 3],
        dropout_ratio=0.1,
        num_classes=1,
        loss_decode=CrossEntropyLoss(),
        align_corners=False,
        upsample=cfg.upsample,
        channels=128,
    )
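    # mmseg-style decode-head settings: the head reads all four pyramid levels
    # (in_index) and predicts a single-channel mask (num_classes=1). Note that
    # the objective actually optimized during training is the normalized focal
    # loss configured in train() below.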

    model = PlainVitModel(
        use_disks=True,        # encode clicks as binary disks rather than Gaussians
        norm_radius=5,         # disk radius in pixels
        with_prev_mask=True,   # feed the previous prediction back as an input channel
        backbone_params=backbone_params,
        neck_params=neck_params,
        head_params=head_params,
        random_split=cfg.random_split,
    )

    # Pretrained-weight loading is left disabled: the MAE_BASE checkpoint
    # targets the 768-dim ViT-Base backbone and would not fit this 160-dim
    # variant, so the backbone trains from scratch.
    # model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_BASE)
    model.to(cfg.device)

    return model, model_cfg


def train(model, cfg, model_cfg):
    cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size
    cfg.val_batch_size = cfg.batch_size
    crop_size = model_cfg.crop_size

    loss_cfg = edict()
    # Normalized focal loss (from RITM) on the predicted instance mask.
    loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
    loss_cfg.instance_loss_weight = 1.0

    train_augmentator = Compose([
        UniformRandomResize(scale_range=(0.75, 1.40)),
        HorizontalFlip(),
        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
        RandomCrop(*crop_size),
        RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75),
        RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75)
    ], p=1.0)

    val_augmentator = Compose([
        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
        RandomCrop(*crop_size)
    ], p=1.0)

    points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80,
                                       merge_objects_prob=0.15,
                                       max_num_merged_objects=2)
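    # Simulates up to num_max_points (24) clicks per sample; prob_gamma decays
    # the probability of drawing many clicks, and 15% of samples merge up to
    # two objects into a single target (a reading of the RITM-style sampler;
    # verify against isegm's MultiPointSampler).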

    trainset = CocoLvisDataset(
        cfg.LVIS_v1_PATH,
        split='train',
        augmentator=train_augmentator,
        min_object_area=1000,
        keep_background_prob=0.05,
        points_sampler=points_sampler,
        epoch_len=30000,
        stuff_prob=0.30
    )
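    # An "epoch" here is a fixed draw of 30k samples, not a full pass over
    # COCO+LVIS; stuff_prob=0.30 allows non-instance "stuff" regions to be
    # sampled as targets (per the RITM CocoLvisDataset semantics).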

    valset = CocoLvisDataset(
        cfg.LVIS_v1_PATH,
        split='val',
        augmentator=val_augmentator,
        min_object_area=1000,
        points_sampler=points_sampler,
        epoch_len=2000
    )

    optimizer_params = {
        'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8
    }

    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
                           milestones=[200, 230], gamma=0.1)
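    # 10x LR decay at epoch 200; the second milestone (230) coincides with the
    # end of the 230-epoch run below, so only the first decay takes effect.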

    trainer = ISTrainer(model, cfg, model_cfg, loss_cfg,
                        trainset, valset,
                        optimizer='adam',
                        optimizer_params=optimizer_params,
                        layerwise_decay=cfg.layerwise_decay,
                        lr_scheduler=lr_scheduler,
                        checkpoint_interval=[(0, 50), (300, 5)],
                        image_dump_interval=1000,
                        metrics=[AdaptiveIoU()],
                        max_interactive_points=model_cfg.num_max_points,
                        max_num_next_clicks=3)
    trainer.run(num_epochs=230, validation=False)
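
For reference, configs in this RITM-derived codebase are run by passing the config path to the repository's train.py, which supplies the cfg fields used above (cfg.device, cfg.upsample, cfg.random_split, cfg.LVIS_v1_PATH, and so on). A typical launch looks like the line below; the exact flag names (--batch-size, --ngpus) are an assumption here, so check the repository's README:

python train.py models/iter_mask/plainvit_xtiny448_cocolvis_itermask.py --batch-size=32 --ngpus=1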
