This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Implement RegNetZ Model #713

Closed
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@

## What's New:

- March 2021: Added [RegNetZ models](https://arxiv.org/abs/2103.06877)
- November 2020: [Vision Transformers](https://openreview.net/forum?id=YicbFdNTTy) now available, with training [recipes](https://github.com/facebookresearch/ClassyVision/tree/master/examples/vit)!

<details>
2 changes: 1 addition & 1 deletion classy_vision/generic/profiler.py
@@ -172,7 +172,7 @@ def flops(self, x):
flops = count1 + count2

# non-linearities:
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
flops = x.numel()

# 2D pooling layers:
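For reference, the profiler counts elementwise non-linearities as one operation per input element, and this change simply extends that rule to `SiLU`. A minimal standalone sketch of the rule; `count_activation_flops` is a hypothetical helper here, not part of `classy_vision.generic.profiler`:

```python
import torch
import torch.nn as nn

def count_activation_flops(layer: nn.Module, x: torch.Tensor) -> int:
    # Elementwise activations are counted as one FLOP per element of the input.
    elementwise = (nn.ReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Softmax, nn.SiLU)
    if isinstance(layer, elementwise):
        return x.numel()
    raise NotImplementedError(f"No FLOP rule for {type(layer).__name__}")

x = torch.randn(1, 64, 56, 56)
print(count_activation_flops(nn.SiLU(), x))  # 200704 (= 1 * 64 * 56 * 56)
```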
67 changes: 58 additions & 9 deletions classy_vision/heads/fully_connected_head.py
@@ -4,13 +4,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any, Dict, Optional

import torch.nn as nn
-from classy_vision.generic.util import is_pos_int
+from classy_vision.generic.util import get_torch_version, is_pos_int
from classy_vision.heads import ClassyHead, register_head


RELU_IN_PLACE = True


@register_head("fully_connected")
class FullyConnectedHead(ClassyHead):
"""This head defines a 2d average pooling layer
@@ -21,27 +24,56 @@ class FullyConnectedHead(ClassyHead):
def __init__(
self,
unique_id: str,
-        num_classes: int,
+        num_classes: Optional[int],
in_plane: int,
conv_planes: Optional[int] = None,
activation: Optional[nn.Module] = None,
zero_init_bias: bool = False,
normalize_inputs: Optional[str] = None,
):
"""Constructor for FullyConnectedHead

Args:
unique_id: A unique identifier for the head. Multiple instances of
the same head might be attached to a model, and unique_id is used
to refer to them.

num_classes: Number of classes for the head. If None, then the fully
connected layer is not applied.

in_plane: Input size for the fully connected layer.
conv_planes: If specified, applies a 1x1 convolutional layer to the input
before passing it to the average pooling layer. The convolution is also
followed by a BatchNorm and an activation.
activation: The activation to be applied after the convolutional layer.
Unused if `conv_planes` is not specified.
zero_init_bias: Zero initialize the bias
normalize_inputs: If specified, normalize the inputs after performing
average pooling using the specified method. Supports "l2" normalization.
"""
super().__init__(unique_id, num_classes)
assert num_classes is None or is_pos_int(num_classes)
assert is_pos_int(in_plane)
if conv_planes is not None and activation is None:
raise TypeError("activation cannot be None if conv_planes is specified")
if normalize_inputs is not None and normalize_inputs != "l2":
raise ValueError(
f"Unsupported value for normalize_inputs: {normalize_inputs}"
)
self.conv = (
nn.Conv2d(in_plane, conv_planes, kernel_size=1, bias=False)
if conv_planes
else None
)
self.bn = nn.BatchNorm2d(conv_planes) if conv_planes else None
self.activation = activation
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = None if num_classes is None else nn.Linear(in_plane, num_classes)
+        self.fc = (
+            None
+            if num_classes is None
+            else nn.Linear(
+                in_plane if conv_planes is None else conv_planes, num_classes
+            )
+        )
self.normalize_inputs = normalize_inputs

if zero_init_bias:
self.fc.bias.data.zero_()
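As a rough illustration of how the new constructor arguments compose (the shapes and `unique_id` below are made up, not taken from the PR): with `conv_planes` set, the input first goes through a 1x1 conv, BatchNorm, and the supplied activation, and the final linear layer then maps `conv_planes` features to `num_classes`; pooling and normalization happen in `forward`, shown further down.

```python
import torch
import torch.nn as nn
from classy_vision.heads.fully_connected_head import FullyConnectedHead

# Hypothetical shapes: reduce a 2048-channel feature map to 1024 channels with
# the new 1x1 conv + BN + SiLU block, average-pool, l2-normalize, then classify.
head = FullyConnectedHead(
    unique_id="default_head",
    num_classes=1000,
    in_plane=2048,
    conv_planes=1024,
    activation=nn.SiLU(),
    normalize_inputs="l2",
)

features = torch.randn(8, 2048, 7, 7)
logits = head(features)
print(logits.shape)  # torch.Size([8, 1000])
```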
@@ -59,19 +91,36 @@ def from_config(cls, config: Dict[str, Any]) -> "FullyConnectedHead":
"""
num_classes = config.get("num_classes", None)
in_plane = config["in_plane"]
silu = None if get_torch_version() < [1, 7] else nn.SiLU()
activation = {"relu": nn.ReLU(RELU_IN_PLACE), "silu": silu}[
config.get("activation", "relu")
]
if activation is None:
raise RuntimeError("SiLU activation is only supported since PyTorch 1.7")
return cls(
config["unique_id"],
num_classes,
in_plane,
conv_planes=config.get("conv_planes", None),
activation=activation,
zero_init_bias=config.get("zero_init_bias", False),
normalize_inputs=config.get("normalize_inputs", None),
)

def forward(self, x):
-        # perform average pooling:
-        out = self.avgpool(x)
+        out = x
+        if self.conv is not None:
+            out = self.activation(self.bn(self.conv(x)))
+
+        out = self.avgpool(out)

# final classifier:
out = out.flatten(start_dim=1)

if self.normalize_inputs is not None:
if self.normalize_inputs == "l2":
out = nn.functional.normalize(out, p=2.0, dim=1)

if self.fc is not None:
out = self.fc(out)

return out
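The same head can also be built from a config via `from_config`, where the new keys mirror the constructor arguments and `"activation"` accepts `"relu"` or `"silu"` (the latter only on PyTorch >= 1.7, otherwise a RuntimeError is raised). A sketch of the config path, assuming construction through `classy_vision.heads.build_head`; the values are illustrative:

```python
from classy_vision.heads import build_head

head = build_head(
    {
        "name": "fully_connected",
        "unique_id": "default_head",
        "num_classes": 1000,
        "in_plane": 2048,
        "conv_planes": 1024,
        "activation": "silu",      # requires PyTorch >= 1.7
        "zero_init_bias": True,
        "normalize_inputs": "l2",
    }
)
```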
40 changes: 32 additions & 8 deletions classy_vision/heads/vision_transformer_head.py
@@ -14,6 +14,7 @@

import copy
from collections import OrderedDict
from typing import Optional

import torch.nn as nn
from classy_vision.heads import ClassyHead, register_head
@@ -25,12 +26,31 @@
class VisionTransformerHead(ClassyHead):
def __init__(
self,
-        in_plane,
-        num_classes,
-        hidden_dim=None,
+        unique_id: str,
+        in_plane: int,
+        num_classes: Optional[int] = None,
+        hidden_dim: Optional[int] = None,
+        normalize_inputs: Optional[str] = None,
):
-        super().__init__()
-        if hidden_dim is None:
+        """
+        Args:
+            unique_id: A unique identifier for the head
+            in_plane: Input size for the fully connected layer
+            num_classes: Number of output classes for the head
+            hidden_dim: If not None, a hidden layer with the specific dimension is added
+            normalize_inputs: If specified, normalize the inputs using the specified
+                method. Supports "l2" normalization.
+        """
+        super().__init__(unique_id, num_classes)
+
+        if normalize_inputs is not None and normalize_inputs != "l2":
+            raise ValueError(
+                f"Unsupported value for normalize_inputs: {normalize_inputs}"
+            )
+
+        if num_classes is None:
+            layers = []
+        elif hidden_dim is None:
layers = [("head", nn.Linear(in_plane, num_classes))]
else:
layers = [
@@ -39,6 +59,7 @@ def __init__(self):
("head", nn.Linear(hidden_dim, num_classes)),
]
self.layers = nn.Sequential(OrderedDict(layers))
self.normalize_inputs = normalize_inputs
self.init_weights()

def init_weights(self):
@@ -47,14 +68,17 @@ def init_weights(self):
self.layers.pre_logits.weight, fan_in=self.layers.pre_logits.in_features
)
nn.init.zeros_(self.layers.pre_logits.bias)
-        nn.init.zeros_(self.layers.head.weight)
-        nn.init.zeros_(self.layers.head.bias)
+        if hasattr(self.layers, "head"):
+            nn.init.zeros_(self.layers.head.weight)
+            nn.init.zeros_(self.layers.head.bias)

@classmethod
def from_config(cls, config):
config = copy.deepcopy(config)
config.pop("unique_id")
return cls(**config)

def forward(self, x):
if self.normalize_inputs is not None:
if self.normalize_inputs == "l2":
x = nn.functional.normalize(x, p=2.0, dim=1)
return self.layers(x)
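With `num_classes=None` the head now builds an empty `nn.Sequential`, so it simply passes the transformer features through, optionally l2-normalized, which is useful for feature extraction or embedding-style training. A small sketch; the `unique_id` and `in_plane=768` (a ViT-B hidden size) are illustrative:

```python
import torch
from classy_vision.heads.vision_transformer_head import VisionTransformerHead

feature_head = VisionTransformerHead(
    unique_id="feature_head",
    in_plane=768,
    num_classes=None,       # no classifier layers are created
    normalize_inputs="l2",  # features are l2-normalized in forward()
)

tokens = torch.randn(4, 768)   # e.g. per-image class-token embeddings
embeddings = feature_head(tokens)
print(embeddings.shape)        # torch.Size([4, 768])
print(embeddings.norm(dim=1))  # ~1.0 for every row
```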