Skip to content

Commit

Permalink
Added annotation typing to resnet (#2863)
Browse files Browse the repository at this point in the history
* style: Added annotation typing for resnet

* fix: Fixed annotation to pass classes

* fix: Fixed annotation typing

* fix: Fixed annotation typing

* fix: Fixed annotation typing for resnet

* refactor: Removed un-necessary import

* fix: Fixed constructor typing

* style: Added black formatting on _resnet
  • Loading branch information
frgfm authored Oct 23, 2020
1 parent 65591f1 commit 2ce6b18
Showing 1 changed file with 66 additions and 30 deletions.
96 changes: 66 additions & 30 deletions torchvision/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import Tensor
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
Expand All @@ -21,22 +23,31 @@
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
expansion: int = 1

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
Expand All @@ -53,7 +64,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
self.downsample = downsample
self.stride = stride

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x

out = self.conv1(x)
Expand All @@ -79,10 +90,19 @@ class Bottleneck(nn.Module):
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
expansion: int = 4

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
Expand All @@ -98,7 +118,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
self.downsample = downsample
self.stride = stride

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x

out = self.conv1(x)
Expand All @@ -123,9 +143,17 @@ def forward(self, x):

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
Expand Down Expand Up @@ -170,11 +198,12 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]

def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
Expand All @@ -198,7 +227,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

return nn.Sequential(*layers)

def _forward_impl(self, x):
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
Expand All @@ -216,11 +245,18 @@ def _forward_impl(self, x):

return x

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
Expand All @@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
return model


def resnet18(pretrained=False, progress=True, **kwargs):
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand All @@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand All @@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand All @@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand All @@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand All @@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Expand All @@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Expand All @@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
Expand All @@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
Expand Down

0 comments on commit 2ce6b18

Please sign in to comment.