Implement RegNetZ Model (facebookresearch#713)
Summary:
Pull Request resolved: facebookresearch#713

Add implementation of RegNetZ models, as per https://arxiv.org/abs/2103.06877

RegNetZ models are trained with a convolutional fully connected head (the fully_connected ClassyHead with conv_planes > 0).

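For context, a minimal sketch of instantiating one of the new models through ClassyVision's `build_model` entry point; the config keys mirror the RegNetZ test preset added in this diff:

```python
# Minimal sketch: config keys mirror the RegNetZ test preset added below.
from classy_vision.models import build_model

regnet_z_config = {
    "name": "regnet",
    "block_type": "res_bottleneck_linear_block",
    "depth": 21,
    "w_0": 16,
    "w_a": 10.7,
    "w_m": 2.51,
    "group_width": 4,
    "bot_mul": 4.0,
    "activation": "silu",
}
model = build_model(regnet_z_config)
# The registered preset should be equivalent:
# model = build_model({"name": "regnet_z_500mf"})
```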
Differential Revision: D27028613

fbshipit-source-id: 8418e1ec7155f1cbfd348a91c94fd45eef61b5f9
mannatsingh authored and facebook-github-bot committed Mar 22, 2021
1 parent c5a7c38 commit 6761129
Showing 4 changed files with 123 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@

## What's New:

+- March 2021: Added [RegNetZ models](https://arxiv.org/abs/2103.06877)
- November 2020: [Vision Transformers](https://openreview.net/forum?id=YicbFdNTTy) now available, with training [recipes](https://github.com/facebookresearch/ClassyVision/tree/master/examples/vit)!

<details>
2 changes: 1 addition & 1 deletion classy_vision/generic/profiler.py
@@ -172,7 +172,7 @@ def flops(self, x):
flops = count1 + count2

# non-linearities:
-elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
+elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
flops = x.numel()

# 2D pooling layers:
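The hunk above extends the profiler's FLOP rule for element-wise non-linearities to SiLU. As a rough standalone illustration (not the profiler's actual API), each activation is charged one FLOP per input element:

```python
# Rough illustration of the rule above: element-wise activations such as
# SiLU are charged one FLOP per input element, i.e. x.numel().
import torch

x = torch.randn(1, 16, 32, 32)
silu_flops = x.numel()  # 1 * 16 * 32 * 32 = 16384 FLOPs for this input
```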
127 changes: 106 additions & 21 deletions classy_vision/models/regnet.py
@@ -25,6 +25,7 @@ class BlockType(Enum):
VANILLA_BLOCK = auto()
RES_BASIC_BLOCK = auto()
RES_BOTTLENECK_BLOCK = auto()
+RES_BOTTLENECK_LINEAR_BLOCK = auto()


# The different possible Stems
@@ -206,8 +207,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
-bot_mul: float,
group_width: int,
+bot_mul: float,
se_ratio: Optional[float],
):
super().__init__()
@@ -253,8 +254,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
-bot_mul: float = 1.0,
group_width: int = 1,
+bot_mul: float = 1.0,
se_ratio: Optional[float] = None,
):
super().__init__()
@@ -273,8 +274,8 @@ def __init__(
bn_epsilon,
bn_momentum,
activation,
-bot_mul,
group_width,
+bot_mul,
se_ratio,
)
self.activation = activation
@@ -291,6 +292,41 @@ def forward(self, x, *args):
return self.activation(x)


+class ResBottleneckLinearBlock(nn.Module):
+    """Residual linear bottleneck block: x + F(x), F = bottleneck transform."""
+
+    def __init__(
+        self,
+        width_in: int,
+        width_out: int,
+        stride: int,
+        bn_epsilon: float,
+        bn_momentum: float,
+        activation: nn.Module,
+        group_width: int = 1,
+        bot_mul: float = 4.0,
+        se_ratio: Optional[float] = None,
+    ):
+        super().__init__()
+        self.has_skip = (width_in == width_out) and (stride == 1)
+        self.f = BottleneckTransform(
+            width_in,
+            width_out,
+            stride,
+            bn_epsilon,
+            bn_momentum,
+            activation,
+            group_width,
+            bot_mul,
+            se_ratio,
+        )
+
+        self.depth = self.f.depth
+
+    def forward(self, x):
+        return x + self.f(x) if self.has_skip else self.f(x)

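A usage sketch of the new block, assuming it can be imported from classy_vision.models.regnet; with width_in == width_out and stride == 1 the skip connection is active. The widths and settings below are illustrative, not taken from a preset:

```python
# Illustrative sketch; widths/settings below are hand-picked, not from a preset.
import torch
import torch.nn as nn

from classy_vision.models.regnet import ResBottleneckLinearBlock

block = ResBottleneckLinearBlock(
    width_in=48,
    width_out=48,
    stride=1,
    bn_epsilon=1e-05,
    bn_momentum=0.1,
    activation=nn.SiLU(),
    group_width=4,
    bot_mul=4.0,
)
x = torch.randn(2, 48, 32, 32)
y = block(x)  # has_skip is True here, so y = x + F(x)
assert y.shape == x.shape
```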

class AnyStage(nn.Sequential):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""

@@ -302,8 +338,8 @@ def __init__(
depth: int,
block_constructor: nn.Module,
activation: nn.Module,
-bot_mul: float,
group_width: int,
+bot_mul: float,
params: "RegNetParams",
stage_index: int = 0,
):
@@ -318,8 +354,8 @@ def __init__(
params.bn_epsilon,
params.bn_momentum,
activation,
-bot_mul,
group_width,
+bot_mul,
params.se_ratio,
)

@@ -354,10 +390,11 @@ def __init__(
w_a: float,
w_m: float,
group_w: int,
-stem_type: StemType = "SIMPLE_STEM_IN",
+bot_mul: float = 1.0,
+stem_type: StemType = StemType.SIMPLE_STEM_IN,
stem_width: int = 32,
-block_type: BlockType = "RES_BOTTLENECK_BLOCK",
-activation_type: ActivationType = "RELU",
+block_type: BlockType = BlockType.RES_BOTTLENECK_BLOCK,
+activation: ActivationType = ActivationType.RELU,
use_se: bool = True,
se_ratio: float = 0.25,
bn_epsilon: float = 1e-05,
@@ -371,9 +408,10 @@ def __init__(
self.w_a = w_a
self.w_m = w_m
self.group_w = group_w
-self.stem_type = StemType[stem_type]
-self.block_type = BlockType[block_type]
-self.activation_type = ActivationType[activation_type]
+self.bot_mul = bot_mul
+self.stem_type = stem_type
+self.block_type = block_type
+self.activation = activation
self.stem_width = stem_width
self.use_se = use_se
self.se_ratio = se_ratio if use_se else None
@@ -403,7 +441,6 @@ def get_expanded_params(self):

QUANT = 8
STRIDE = 2
-BOT_MUL = 1.0

# Compute the block widths. Each stage has one unique block width
widths_cont = np.arange(self.depth) * self.w_a + self.w_0
@@ -428,21 +465,26 @@ def get_expanded_params(self):
stage_depths = np.diff([d for d, t in enumerate(splits) if t]).tolist()

strides = [STRIDE] * num_stages
-bot_muls = [BOT_MUL] * num_stages
+bot_muls = [self.bot_mul] * num_stages
group_widths = [self.group_w] * num_stages

# Adjust the compatibility of stage widths and group widths
stage_widths, group_widths = _adjust_widths_groups_compatibilty(
stage_widths, bot_muls, group_widths
)

-return zip(stage_widths, strides, stage_depths, bot_muls, group_widths)
+return zip(stage_widths, strides, stage_depths, group_widths, bot_muls)

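For intuition, a self-contained sketch of the width schedule get_expanded_params computes, evaluated for the regnet_z_500mf parameters added later in this diff. The quantization step (QUANT = 8) follows the RegNet reference implementation, since those lines are elided here:

```python
import numpy as np

# Sketch of the RegNet width parameterization (https://arxiv.org/abs/2003.13678)
# for regnet_z_500mf: depth=21, w_0=16, w_a=10.7, w_m=2.51.
depth, w_0, w_a, w_m, QUANT = 21, 16, 10.7, 2.51, 8

widths_cont = np.arange(depth) * w_a + w_0               # per-block linear widths
ks = np.round(np.log(widths_cont / w_0) / np.log(w_m))  # quantized exponents
widths = (np.round(w_0 * np.power(w_m, ks) / QUANT) * QUANT).astype(int)
num_stages = len(np.unique(widths))                      # one unique width per stage
```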

@register_model("regnet")
class RegNet(ClassyModel):
"""Implementation of RegNet, a particular form of AnyNets
See https://arxiv.org/abs/2003.13678v1"""
"""Implementation of RegNet, a particular form of AnyNets.
See https://arxiv.org/abs/2003.13678 for introduction to RegNets, and details about
RegNetX and RegNetY models.
See https://arxiv.org/abs/2103.06877 for details about RegNetZ models.
"""

def __init__(self, params: RegNetParams):
super().__init__()
@@ -474,14 +516,15 @@ def __init__(self, params: RegNetParams):
BlockType.VANILLA_BLOCK: VanillaBlock,
BlockType.RES_BASIC_BLOCK: ResBasicBlock,
BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
+BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
}[params.block_type]

current_width = params.stem_width

self.trunk_depth = 0

blocks = []
-for i, (width_out, stride, depth, bot_mul, group_width) in enumerate(
+for i, (width_out, stride, depth, group_width, bot_mul) in enumerate(
params.get_expanded_params()
):
blocks.append(
Expand All @@ -494,8 +537,8 @@ def __init__(self, params: RegNetParams):
depth,
block_fun,
activation,
-bot_mul,
group_width,
+bot_mul,
params,
stage_index=i + 1,
),
@@ -529,10 +572,13 @@ def from_config(cls, config: Dict[str, Any]) -> "RegNet":
w_a=config["w_a"],
w_m=config["w_m"],
group_w=config["group_width"],
-stem_type=config.get("stem_type", "simple_stem_in").upper(),
+bot_mul=config.get("bot_mul", 1.0),
+stem_type=StemType[config.get("stem_type", "simple_stem_in").upper()],
stem_width=config.get("stem_width", 32),
-block_type=config.get("block_type", "res_bottleneck_block").upper(),
-activation_type=config.get("activation_type", "relu").upper(),
+block_type=BlockType[
+    config.get("block_type", "res_bottleneck_block").upper()
+],
+activation=ActivationType[config.get("activation", "relu").upper()],
use_se=config.get("use_se", True),
se_ratio=config.get("se_ratio", 0.25),
bn_epsilon=config.get("bn_epsilon", 1e-05),
@@ -751,6 +797,45 @@ def __init__(self, **kwargs):
)


+# note that RegNetZ models are trained with a convolutional head, i.e. the
+# fully_connected ClassyHead with conv_planes > 0.
+@register_model("regnet_z_500mf")
+class RegNetZ500mf(_RegNet):
+    def __init__(self, **kwargs):
+        super().__init__(
+            RegNetParams(
+                depth=21,
+                w_0=16,
+                w_a=10.7,
+                w_m=2.51,
+                group_w=4,
+                bot_mul=4.0,
+                block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
+                activation=ActivationType.SILU,
+                **kwargs,
+            )
+        )
+
+
+# this is supposed to be trained with a resolution of 256x256
+@register_model("regnet_z_4gf")
+class RegNetZ4gf(_RegNet):
+    def __init__(self, **kwargs):
+        super().__init__(
+            RegNetParams(
+                depth=28,
+                w_0=48,
+                w_a=14.5,
+                w_m=2.226,
+                group_w=8,
+                bot_mul=4.0,
+                block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
+                activation=ActivationType.SILU,
+                **kwargs,
+            )
+        )

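Per the comment above, RegNetZ models pair with a convolutional fully_connected head. A hypothetical config sketch follows; the fork_block, in_plane, and conv_planes values are assumptions for illustration, not part of this diff:

```python
# Hypothetical sketch: the head attachment point and widths below are assumed.
from classy_vision.models import build_model

model = build_model({
    "name": "regnet_z_500mf",
    "heads": [
        {
            "name": "fully_connected",
            "unique_id": "default_head",
            "num_classes": 1000,
            "fork_block": "trunk_output",  # assumed attach point
            "in_plane": 1024,              # assumed trunk output width
            "conv_planes": 1024,           # conv_planes > 0 -> convolutional head
        }
    ],
})
```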

# -----------------------------------------------------------------------------------
# The following models were not part of the original publication,
# (https://arxiv.org/abs/2003.13678v1), but are larger versions of the
15 changes: 15 additions & 0 deletions test/models_regnet_test.py
@@ -107,6 +107,20 @@
"group_width": 56,
},
),
+    (
+        {
+            # RegNetZ
+            "name": "regnet",
+            "block_type": "res_bottleneck_linear_block",
+            "depth": 21,
+            "w_0": 16,
+            "w_a": 10.7,
+            "w_m": 2.51,
+            "group_width": 4,
+            "bot_mul": 4.0,
+            "activation": "silu",
+        },
+    ),
]


@@ -128,6 +142,7 @@
"regnet_x_8gf",
"regnet_x_16gf",
"regnet_x_32gf",
"regnet_z_500mf",
]

REGNET_TEST_PRESETS = [({"name": n},) for n in REGNET_TEST_PRESET_NAMES]
