Skip to content

Commit

Permalink
fixed errors in function comments
Browse files Browse the repository at this point in the history
VainF committed Jan 17, 2020
1 parent b9457dd commit 6656d16
Showing 1 changed file with 45 additions and 60 deletions.
105 changes: 45 additions & 60 deletions network/modeling.py
Original file line number Diff line number Diff line change
@@ -2,14 +2,8 @@
from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
from .backbone import resnet
from .backbone import mobilenetv2
from torchvision.models.utils import load_state_dict_from_url

model_urls = {
'deeplabv3_resnet50_coco': None,
'deeplabv3_resnet101_coco': None,
}

def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone=True):
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):

if output_stride==8:
replace_stride_with_dilation=[False, True, True]
@@ -36,14 +30,15 @@ def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_bac
model = DeepLabV3(backbone, classifier)
return model

def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone=True):
def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
if output_stride==8:
aspp_dilate = [12, 24, 36]
else:
aspp_dilate = [6, 12, 18]

backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)

# rename layers
backbone.low_level_features = backbone.features[0:4]
backbone.high_level_features = backbone.features[4:-1]
backbone.features = None
@@ -63,90 +58,80 @@ def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_
model = DeepLabV3(backbone, classifier)
return model

def _load_model(arch_type, backbone, pretrained, progress, num_classes, output_stride=8):
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):

if backbone=='mobilenetv2':
model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride)
model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
elif backbone.startswith('resnet'):
model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride)
model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
else:
raise NotImplementedError

if pretrained:
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model

def deeplabv3_resnet50(pretrained=False, progress=True,
num_classes=21, **kwargs):

# Deeplab v3

def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)

def deeplabv3_resnet101(pretrained=False, progress=True,
num_classes=21, **kwargs):
def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)

def deeplabv3plus_resnet50(pretrained=False, progress=True,
num_classes=21, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
"""Constructs a DeepLabV3 model with a MobileNetv2 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3plus', 'resnet50', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)


def deeplabv3plus_resnet101(pretrained=False, progress=True,
num_classes=21, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
# Deeplab v3+

def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3plus', 'resnet101', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)


def deeplabv3_mobilenet(pretrained=False, progress=True,
num_classes=21, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
"""Constructs a DeepLabV3+ model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3', 'mobilenetv2', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3plus', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)


def deeplabv3plus_mobilenet(pretrained=False, progress=True,
num_classes=21, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
"""Constructs a DeepLabV3+ model with a MobileNetv2 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of classes.
output_stride (int): output stride for deeplab.
pretrained_backbone (bool): If True, use the pretrained backbone.
"""
return _load_model('deeplabv3plus', 'mobilenetv2', pretrained, progress, num_classes, **kwargs)
return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)

0 comments on commit 6656d16

Please sign in to comment.