Implement RegNetZ Model (#713)
Summary:
Pull Request resolved: #713

Add an implementation of RegNetZ models, as described in https://arxiv.org/abs/2103.06877

RegNetZ models are trained with a convolutional fully connected head (the `fully_connected` ClassyHead with `conv_planes > 0`); a usage sketch follows below.
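Below is a hypothetical sketch (not part of this commit) of how such a head could be attached. It assumes the existing `build_model` / `build_head` APIs; the `in_plane` / `conv_planes` values and the `"trunk_output"` attachment point are illustrative placeholders, not values taken from this diff.

```python
from classy_vision.heads import build_head
from classy_vision.models import build_model

# Build the RegNetZ-500MF trunk registered by this commit.
model = build_model({"name": "regnet_z_500mf"})

# A fully_connected head with conv_planes > 0 places a conv layer before the
# final linear layer. The widths and attachment point below are assumptions.
head = build_head(
    {
        "name": "fully_connected",
        "unique_id": "default_head",
        "num_classes": 1000,
        "in_plane": 1024,  # assumed trunk output width
        "conv_planes": 1536,  # assumed conv head width
    }
)
model.set_heads({"trunk_output": [head]})
```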

Differential Revision: D27028613

fbshipit-source-id: 4208bb68d4551e32dbaaee38245fbf948eea5c86
mannatsingh authored and facebook-github-bot committed Mar 16, 2021
1 parent 3d4c25e commit 99a48a2
Showing 4 changed files with 129 additions and 28 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -13,6 +13,7 @@

## What's New:

- March 2021: Added [RegNetZ models](https://arxiv.org/abs/2103.06877)
- November 2020: [Vision Transformers](https://openreview.net/forum?id=YicbFdNTTy) now available, with training [recipes](https://github.com/facebookresearch/ClassyVision/tree/master/examples/vit)!

<details>
@@ -53,16 +54,16 @@
- Add support to convert any `PyTorch` model to a `ClassyModel` with the ability to attach heads to it ([#461](https://github.com/facebookresearch/ClassyVision/pull/461))
- Added a corresponding [tutorial](https://classyvision.ai/tutorials/classy_model) on `ClassyModel` and `ClassyHeads` ([#485](https://github.com/facebookresearch/ClassyVision/pull/485))
- [Squeeze and Excitation](https://arxiv.org/pdf/1709.01507.pdf) support for `ResNe(X)t` and `DenseNet` models ([#426](https://github.com/facebookresearch/ClassyVision/pull/426), [#427](https://github.com/facebookresearch/ClassyVision/pull/427))
- Made `ClassyHook`s registrable ([#401](https://github.com/facebookresearch/ClassyVision/pull/401)) and configurable ([#402](https://github.com/facebookresearch/ClassyVision/pull/402))
- Migrated to [`TorchElastic v0.2.0`](https://pytorch.org/elastic/master/examples.html#classy-vision) ([#464](https://github.com/facebookresearch/ClassyVision/pull/464))
- Add `SyncBatchNorm` support ([#423](https://github.com/facebookresearch/ClassyVision/pull/423))
- Implement [`mixup`](https://arxiv.org/abs/1710.09412) train augmentation ([#469](https://github.com/facebookresearch/ClassyVision/pull/469))
- Support [`LARC`](https://arxiv.org/abs/1708.03888) for SGD optimizer ([#408](https://github.com/facebookresearch/ClassyVision/pull/408))
- Added convenience wrappers for `Iterable` datasets ([#455](https://github.com/facebookresearch/ClassyVision/pull/455))
- `Tensorboard` improvements
- Plot histograms of model weights to Tensorboard ([#432](https://github.com/facebookresearch/ClassyVision/pull/432))
- Reduce data logged to tensorboard ([#436](https://github.com/facebookresearch/ClassyVision/pull/436))
- Invalid (`NaN` / `Inf`) loss detection
- Revamped logging ([#478](https://github.com/facebookresearch/ClassyVision/pull/478))
- Add `bn_weight_decay` configuration option for `ResNe(X)t` models
- Support specifying `update_interval` to Parameter Schedulers ([#418](https://github.com/facebookresearch/ClassyVision/pull/418))
@@ -100,7 +101,7 @@
* Integration with PyTorch Hub. AI researchers and engineers can download and fine-tune the best publicly available ImageNet models with just a few lines of code.
* Elastic training. We have also added experimental integration with [PyTorch Elastic](https://github.com/pytorch/elastic), which allows distributed training jobs to adjust as the available resources in the cluster change. It also makes distributed training robust to transient hardware failures.

Classy Vision is beta software. The project is under active development and our APIs are subject to change in future releases.

## Installation

2 changes: 1 addition & 1 deletion classy_vision/generic/profiler.py
@@ -172,7 +172,7 @@ def flops(self, x):
flops = count1 + count2

# non-linearities:
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
flops = x.numel()

# 2D pooling layers:
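As a quick sanity check of the profiler rule changed above (element-wise non-linearities, now including `SiLU`, cost one FLOP per element):

```python
import torch

# The profiler counts an element-wise activation as x.numel() FLOPs.
x = torch.randn(2, 16, 8, 8)
print(x.numel())  # 2 * 16 * 8 * 8 = 2048 FLOPs for one SiLU over x
```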
131 changes: 108 additions & 23 deletions classy_vision/models/regnet.py
@@ -25,6 +25,7 @@ class BlockType(Enum):
VANILLA_BLOCK = auto()
RES_BASIC_BLOCK = auto()
RES_BOTTLENECK_BLOCK = auto()
RES_BOTTLENECK_LINEAR_BLOCK = auto()


# The different possible Stems
@@ -206,8 +207,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
group_width: int,
bot_mul: float,
se_ratio: Optional[float],
):
super().__init__()
@@ -253,8 +254,8 @@ def __init__(
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
group_width: int = 1,
bot_mul: float = 1.0,
se_ratio: Optional[float] = None,
):
super().__init__()
@@ -273,8 +274,8 @@ def __init__(
bn_epsilon,
bn_momentum,
activation,
group_width,
bot_mul,
se_ratio,
)
self.activation = activation
@@ -291,6 +292,41 @@ def forward(self, x, *args):
return self.activation(x)


class ResBottleneckLinearBlock(nn.Module):
"""Residual linear bottleneck block: x + F(x), F = bottleneck transform."""

def __init__(
self,
width_in: int,
width_out: int,
stride: int,
bn_epsilon: float,
bn_momentum: float,
activation: nn.Module,
group_width: int = 1,
bot_mul: float = 4.0,
se_ratio: Optional[float] = None,
):
super().__init__()
self.has_skip = (width_in == width_out) and (stride == 1)
self.f = BottleneckTransform(
width_in,
width_out,
stride,
bn_epsilon,
bn_momentum,
activation,
group_width,
bot_mul,
se_ratio,
)

self.depth = self.f.depth

def forward(self, x):
return x + self.f(x) if self.has_skip else self.f(x)
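A quick shape check of the block added above (an illustrative sketch, not part of the diff; it assumes `ResBottleneckLinearBlock` is importable from `classy_vision.models.regnet` and that `nn.SiLU` is available, i.e. torch >= 1.7):

```python
import torch
from torch import nn

from classy_vision.models.regnet import ResBottleneckLinearBlock

# With stride 1 and width_in == width_out the identity skip is used; otherwise
# the block reduces to a plain bottleneck transform.
block = ResBottleneckLinearBlock(
    width_in=64,
    width_out=64,
    stride=1,
    bn_epsilon=1e-5,
    bn_momentum=0.1,
    activation=nn.SiLU(),
    group_width=4,
    bot_mul=4.0,
)
x = torch.randn(2, 64, 56, 56)
assert block(x).shape == x.shape  # the skip connection preserves the shape
```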


class AnyStage(nn.Sequential):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""

@@ -302,8 +338,8 @@ def __init__(
depth: int,
block_constructor: nn.Module,
activation: nn.Module,
group_width: int,
bot_mul: float,
params: "RegNetParams",
stage_index: int = 0,
):
@@ -318,8 +354,8 @@ def __init__(
params.bn_epsilon,
params.bn_momentum,
activation,
group_width,
bot_mul,
params.se_ratio,
)

@@ -354,10 +390,11 @@ def __init__(
w_a: float,
w_m: float,
group_w: int,
bot_mul: float = 1.0,
stem_type: StemType = StemType.SIMPLE_STEM_IN,
stem_width: int = 32,
block_type: BlockType = BlockType.RES_BOTTLENECK_BLOCK,
activation: ActivationType = ActivationType.RELU,
use_se: bool = True,
se_ratio: float = 0.25,
bn_epsilon: float = 1e-05,
@@ -371,9 +408,10 @@ def __init__(
self.w_a = w_a
self.w_m = w_m
self.group_w = group_w
self.bot_mul = bot_mul
self.stem_type = stem_type
self.block_type = block_type
self.activation = activation
self.stem_width = stem_width
self.use_se = use_se
self.se_ratio = se_ratio if use_se else None
@@ -403,7 +441,6 @@ def get_expanded_params(self):

QUANT = 8
STRIDE = 2

# Compute the block widths. Each stage has one unique block width
widths_cont = np.arange(self.depth) * self.w_a + self.w_0
@@ -428,26 +465,31 @@ def get_expanded_params(self):
stage_depths = np.diff([d for d, t in enumerate(splits) if t]).tolist()

strides = [STRIDE] * num_stages
bot_muls = [self.bot_mul] * num_stages
group_widths = [self.group_w] * num_stages

# Adjust the compatibility of stage widths and group widths
stage_widths, group_widths = _adjust_widths_groups_compatibilty(
stage_widths, bot_muls, group_widths
)

return zip(stage_widths, strides, stage_depths, group_widths, bot_muls)
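To make the width schedule concrete, here is a standalone sketch of the computation above (following https://arxiv.org/abs/2003.13678; the parameter values are the `regnet_z_500mf` ones added by this commit):

```python
import numpy as np

# Continuous widths follow the line w_j = w_0 + w_a * j, are quantized to
# powers of w_m, then rounded to multiples of QUANT = 8.
w_0, w_a, w_m, depth, QUANT = 16, 10.7, 2.51, 21, 8
widths_cont = np.arange(depth) * w_a + w_0
ks = np.round(np.log(widths_cont / w_0) / np.log(w_m))
widths = np.round(w_0 * np.power(w_m, ks) / QUANT) * QUANT
print(sorted(set(widths.astype(int).tolist())))  # unique per-stage widths
```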


@register_model("regnet")
class RegNet(ClassyModel):
"""Implementation of RegNet, a particular form of AnyNets
See https://arxiv.org/abs/2003.13678v1"""
"""Implementation of RegNet, a particular form of AnyNets.
See https://arxiv.org/abs/2003.13678 for introduction to RegNets, and details about
RegNetX and RegNetY models.
See https://arxiv.org/abs/2103.06877 for details about RegNetZ models.
"""

def __init__(self, params: RegNetParams):
super().__init__()

if params.activation == ActivationType.SILU and get_torch_version() < [
1,
7,
]:
@@ -456,7 +498,7 @@ def __init__(self, params: RegNetParams):
activation = {
ActivationType.RELU: nn.ReLU(params.relu_in_place),
ActivationType.SILU: nn.SiLU(),
}[params.activation]

# Ad hoc stem
self.stem = {
@@ -476,14 +518,15 @@
BlockType.VANILLA_BLOCK: VanillaBlock,
BlockType.RES_BASIC_BLOCK: ResBasicBlock,
BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
}[params.block_type]

current_width = params.stem_width

self.trunk_depth = 0

blocks = []
for i, (width_out, stride, depth, group_width, bot_mul) in enumerate(
params.get_expanded_params()
):
blocks.append(
@@ -496,8 +539,8 @@
depth,
block_fun,
activation,
group_width,
bot_mul,
params,
stage_index=i + 1,
),
@@ -531,10 +574,13 @@ def from_config(cls, config: Dict[str, Any]) -> "RegNet":
w_a=config["w_a"],
w_m=config["w_m"],
group_w=config["group_width"],
bot_mul=config.get("bot_mul", 1.0),
stem_type=StemType[config.get("stem_type", "simple_stem_in").upper()],
stem_width=config.get("stem_width", 32),
block_type=BlockType[
    config.get("block_type", "res_bottleneck_block").upper()
],
activation=ActivationType[config.get("activation", "relu").upper()],
use_se=config.get("use_se", True),
se_ratio=config.get("se_ratio", 0.25),
bn_epsilon=config.get("bn_epsilon", 1e-05),
@@ -753,6 +799,45 @@ def __init__(self, **kwargs):
)


# note that RegNetZ models are trained with a convolutional head, i.e. the
# fully_connected ClassyHead with conv_planes > 0.
@register_model("regnet_z_500mf")
class RegNetZ500mf(_RegNet):
def __init__(self, **kwargs):
super().__init__(
RegNetParams(
depth=21,
w_0=16,
w_a=10.7,
w_m=2.51,
group_w=4,
bot_mul=4.0,
block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
activation=ActivationType.SILU,
**kwargs,
)
)


# This model is intended to be trained at a resolution of 256x256.
@register_model("regnet_z_4gf")
class RegNetZ4gf(_RegNet):
def __init__(self, **kwargs):
super().__init__(
RegNetParams(
depth=28,
w_0=48,
w_a=14.5,
w_m=2.226,
group_w=8,
bot_mul=4.0,
block_type=BlockType.RES_BOTTLENECK_LINEAR_BLOCK,
activation=ActivationType.SILU,
**kwargs,
)
)
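An illustrative smoke test (not from this diff): since the comment above says this model targets 256x256 inputs, a forward pass at that resolution should produce trunk features even with no head attached.

```python
import torch
from classy_vision.models import build_model

model = build_model({"name": "regnet_z_4gf"}).eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 256, 256))  # trunk output; no head attached
```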


# -----------------------------------------------------------------------------------
# The following models were not part of the original publication
# (https://arxiv.org/abs/2003.13678v1), but are larger versions of the
15 changes: 15 additions & 0 deletions test/models_regnet_test.py
@@ -107,6 +107,20 @@
"group_width": 56,
},
),
(
{
# RegNetZ
"name": "regnet",
"block_type": "res_bottleneck_linear_block",
"depth": 21,
"w_0": 16,
"w_a": 10.7,
"w_m": 2.51,
"group_width": 4,
"bot_mul": 4.0,
"activation": "silu",
},
),
]
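Illustrative usage (not part of the test file): any preset above can be passed to `build_model`; for example, the RegNetZ configuration:

```python
from classy_vision.models import build_model

# Mirrors the RegNetZ test preset above.
model = build_model(
    {
        "name": "regnet",
        "block_type": "res_bottleneck_linear_block",
        "depth": 21,
        "w_0": 16,
        "w_a": 10.7,
        "w_m": 2.51,
        "group_width": 4,
        "bot_mul": 4.0,
        "activation": "silu",
    }
)
```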


@@ -128,6 +142,7 @@
"regnet_x_8gf",
"regnet_x_16gf",
"regnet_x_32gf",
"regnet_z_500mf",
]

REGNET_TEST_PRESETS = [({"name": n},) for n in REGNET_TEST_PRESET_NAMES]
