Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Implement RegNetZ Model
Browse files Browse the repository at this point in the history
Summary:
Add implementation of RegNetZ models, as per https://arxiv.org/abs/2103.06877

RegNetZ models are trained with a convolutional fully connected head

Differential Revision: D27028613

fbshipit-source-id: 1d7f4f4e182c618d27c502ba6e84e3057a480e6a
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Mar 13, 2021
1 parent e090bfd commit f0a1fbf
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 17 deletions.
78 changes: 61 additions & 17 deletions classy_vision/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class BlockType(Enum):
VANILLA_BLOCK = auto()
RES_BASIC_BLOCK = auto()
RES_BOTTLENECK_BLOCK = auto()
RES_BOTTLENECK_LINEAR_BLOCK = auto()


# The different possible Stems
Expand Down Expand Up @@ -206,8 +207,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
bot_mul: float,
group_width: int,
bot_mul: float,
se_ratio: Optional[float],
):
super().__init__()
Expand Down Expand Up @@ -253,8 +254,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
bot_mul: float = 1.0,
group_width: int = 1,
bot_mul: float = 1.0,
se_ratio: Optional[float] = None,
):
super().__init__()
Expand All @@ -273,8 +274,8 @@ def __init__(
bn_epsilon,
bn_momentum,
activation,
bot_mul,
group_width,
bot_mul,
se_ratio,
)
self.activation = activation
Expand All @@ -291,6 +292,41 @@ def forward(self, x, *args):
return self.activation(x)


class ResBottleneckLinearBlock(nn.Module):
"""Residual linear bottleneck block: x + F(x), F = bottleneck transform."""

def __init__(
self,
width_in: int,
width_out: int,
stride: int,
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
group_width: int = 1,
bot_mul: float = 4.0,
se_ratio: Optional[float] = None,
):
super().__init__()
self.has_skip = (width_in == width_out) and (stride == 1)
self.f = BottleneckTransform(
width_in,
width_out,
stride,
bn_epsilon,
bn_momentum,
activation,
group_width,
bot_mul,
se_ratio,
)

self.depth = self.f.depth

def forward(self, x):
return x + self.f(x) if self.has_skip else self.f(x)


class AnyStage(nn.Sequential):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""

Expand All @@ -302,8 +338,8 @@ def __init__(
depth: int,
block_constructor: nn.Module,
activation: nn.Module,
bot_mul: float,
group_width: int,
bot_mul: float,
params: "RegNetParams",
stage_index: int = 0,
):
Expand All @@ -318,8 +354,8 @@ def __init__(
params.bn_epsilon,
params.bn_momentum,
activation,
bot_mul,
group_width,
bot_mul,
params.se_ratio,
)

Expand Down Expand Up @@ -354,10 +390,11 @@ def __init__(
w_a: float,
w_m: float,
group_w: int,
bot_mul: float = 1.0,
stem_type: StemType = "SIMPLE_STEM_IN",
stem_width: int = 32,
block_type: BlockType = "RES_BOTTLENECK_BLOCK",
activation_type: ActivationType = "RELU",
activation: ActivationType = "RELU",
use_se: bool = True,
se_ratio: float = 0.25,
bn_epsilon: float = 1e-05,
Expand All @@ -371,9 +408,10 @@ def __init__(
self.w_a = w_a
self.w_m = w_m
self.group_w = group_w
self.bot_mul = bot_mul
self.stem_type = StemType[stem_type]
self.block_type = BlockType[block_type]
self.activation_type = ActivationType[activation_type]
self.activation = ActivationType[activation]
self.stem_width = stem_width
self.use_se = use_se
self.se_ratio = se_ratio if use_se else None
Expand Down Expand Up @@ -403,7 +441,6 @@ def get_expanded_params(self):

QUANT = 8
STRIDE = 2
BOT_MUL = 1.0

# Compute the block widths. Each stage has one unique block width
widths_cont = np.arange(self.depth) * self.w_a + self.w_0
Expand All @@ -428,26 +465,31 @@ def get_expanded_params(self):
stage_depths = np.diff([d for d, t in enumerate(splits) if t]).tolist()

strides = [STRIDE] * num_stages
bot_muls = [BOT_MUL] * num_stages
bot_muls = [self.bot_mul] * num_stages
group_widths = [self.group_w] * num_stages

# Adjust the compatibility of stage widths and group widths
stage_widths, group_widths = _adjust_widths_groups_compatibilty(
stage_widths, bot_muls, group_widths
)

return zip(stage_widths, strides, stage_depths, bot_muls, group_widths)
return zip(stage_widths, strides, stage_depths, group_widths, bot_muls)


@register_model("regnet")
class RegNet(ClassyModel):
"""Implementation of RegNet, a particular form of AnyNets
See https://arxiv.org/abs/2003.13678v1"""
"""Implementation of RegNet, a particular form of AnyNets.
See https://arxiv.org/abs/2003.13678 for introduction to RegNets, and details about
RegNetX and RegNetY models.
See https://arxiv.org/abs/2103.06877 for details about RegNetZ models.
"""

def __init__(self, params: RegNetParams):
super().__init__()

if params.activation_type == ActivationType.SILU and get_torch_version() < [
if params.activation == ActivationType.SILU and get_torch_version() < [
1,
7,
]:
Expand All @@ -456,7 +498,7 @@ def __init__(self, params: RegNetParams):
activation = {
ActivationType.RELU: nn.ReLU(params.relu_in_place),
ActivationType.SILU: nn.SiLU(),
}[params.activation_type]
}[params.activation]

# Ad hoc stem
self.stem = {
Expand All @@ -476,14 +518,15 @@ def __init__(self, params: RegNetParams):
BlockType.VANILLA_BLOCK: VanillaBlock,
BlockType.RES_BASIC_BLOCK: ResBasicBlock,
BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
}[params.block_type]

current_width = params.stem_width

self.trunk_depth = 0

blocks = []
for i, (width_out, stride, depth, bot_mul, group_width) in enumerate(
for i, (width_out, stride, depth, group_width, bot_mul) in enumerate(
params.get_expanded_params()
):
blocks.append(
Expand All @@ -496,8 +539,8 @@ def __init__(self, params: RegNetParams):
depth,
block_fun,
activation,
bot_mul,
group_width,
bot_mul,
params,
stage_index=i + 1,
),
Expand Down Expand Up @@ -531,10 +574,11 @@ def from_config(cls, config: Dict[str, Any]) -> "RegNet":
w_a=config["w_a"],
w_m=config["w_m"],
group_w=config["group_width"],
bot_mul=config.get("bot_mul", 1.0),
stem_type=config.get("stem_type", "simple_stem_in").upper(),
stem_width=config.get("stem_width", 32),
block_type=config.get("block_type", "res_bottleneck_block").upper(),
activation_type=config.get("activation_type", "relu").upper(),
activation=config.get("activation", "relu").upper(),
use_se=config.get("use_se", True),
se_ratio=config.get("se_ratio", 0.25),
bn_epsilon=config.get("bn_epsilon", 1e-05),
Expand Down
13 changes: 13 additions & 0 deletions test/models_regnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@
"group_width": 56,
},
),
(
{
# RegNetZ
"name": "regnet",
"block_type": "res_bottleneck_linear_block",
"depth": 21,
"w_0": 16,
"w_a": 10.7,
"w_m": 2.51,
"group_width": 4,
"bot_mul": 4.0,
},
),
]


Expand Down

0 comments on commit f0a1fbf

Please sign in to comment.