Skip to content

Commit

Permalink
Fix strides in basic block (affects ResNet{18,34}).
Browse files Browse the repository at this point in the history
We (intentionally) implement ResNet V1.5 (downsamples at the second conv in
bottleneck blocks, while original paper downsamples in the first conv) and V2.

Before this change for basic blocks (as used in `ResNet{18,34}` and when passing
`ResNet(.., bottleneck=False)`) we also downsampled in the second conv. This
is a bug, for basic blocks downsampling should occur in the first conv.

This bug actually improved performance in affected ResNets (for standard
training on ImageNet this patch causes a -1.0% regression in top_1 accuracy for
ResNet18 and -0.5% for ResNet34). We are submitting it anyway to remain faithful
to the paper.

![tensorboard](https://user-images.githubusercontent.com/28017/102883943-b3371e00-4448-11eb-9ee5-93e5e9b47225.png)

Fixes #85.

PiperOrigin-RevId: 348638376
Change-Id: I8849cbf22ae587fc597805c4420043f471efa80a
  • Loading branch information
tomhennigan authored and copybara-github committed Dec 22, 2020
1 parent 3a20795 commit 300e6a4
Showing 1 changed file with 61 additions and 32 deletions.
93 changes: 61 additions & 32 deletions haiku/_src/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
conv_0 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=1 if bottleneck else 3,
stride=1,
stride=1 if bottleneck else stride,
with_bias=False,
padding="SAME",
name="conv_0")
Expand All @@ -79,7 +79,7 @@ def __init__(
conv_1 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=3,
stride=stride,
stride=stride if bottleneck else 1,
with_bias=False,
padding="SAME",
name="conv_1")
Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(
conv_0 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=1 if bottleneck else 3,
stride=1,
stride=1 if bottleneck else stride,
with_bias=False,
padding="SAME",
name="conv_0")
Expand All @@ -159,7 +159,7 @@ def __init__(
conv_1 = hk.Conv2D(
output_channels=channels // channel_div,
kernel_shape=3,
stride=stride,
stride=stride if bottleneck else 1,
with_bias=False,
padding="SAME",
name="conv_1")
Expand Down Expand Up @@ -239,6 +239,45 @@ def check_length(length, value, name):
class ResNet(hk.Module):
"""ResNet model."""

CONFIGS = {
18: {
"blocks_per_group": (2, 2, 2, 2),
"bottleneck": False,
"channels_per_group": (64, 128, 256, 512),
"use_projection": (False, True, True, True),
},
34: {
"blocks_per_group": (3, 4, 6, 3),
"bottleneck": False,
"channels_per_group": (64, 128, 256, 512),
"use_projection": (False, True, True, True),
},
50: {
"blocks_per_group": (3, 4, 6, 3),
"bottleneck": True,
"channels_per_group": (256, 512, 1024, 2048),
"use_projection": (True, True, True, True),
},
101: {
"blocks_per_group": (3, 4, 23, 3),
"bottleneck": True,
"channels_per_group": (256, 512, 1024, 2048),
"use_projection": (True, True, True, True),
},
152: {
"blocks_per_group": (3, 8, 36, 3),
"bottleneck": True,
"channels_per_group": (256, 512, 1024, 2048),
"use_projection": (True, True, True, True),
},
200: {
"blocks_per_group": (3, 24, 36, 3),
"bottleneck": True,
"channels_per_group": (256, 512, 1024, 2048),
"use_projection": (True, True, True, True),
},
}

BlockGroup = BlockGroup # pylint: disable=invalid-name
BlockV1 = BlockV1 # pylint: disable=invalid-name
BlockV2 = BlockV2 # pylint: disable=invalid-name
Expand Down Expand Up @@ -364,15 +403,12 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(2, 2, 2, 2),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=False,
channels_per_group=(64, 128, 256, 512),
use_projection=(False, True, True, True),
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[18])


class ResNet34(ResNet):
Expand All @@ -395,15 +431,12 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 6, 3),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=False,
channels_per_group=(64, 128, 256, 512),
use_projection=(False, True, True, True),
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[34])


class ResNet50(ResNet):
Expand All @@ -426,13 +459,12 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 6, 3),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[50])


class ResNet101(ResNet):
Expand All @@ -455,13 +487,12 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 4, 23, 3),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[101])


class ResNet152(ResNet):
Expand All @@ -484,13 +515,12 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 8, 36, 3),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[152])


class ResNet200(ResNet):
Expand All @@ -513,10 +543,9 @@ def __init__(self,
logits_config: A dictionary of keyword arguments for the logits layer.
name: Name of the module.
"""
super().__init__(blocks_per_group=(3, 24, 36, 3),
num_classes=num_classes,
super().__init__(num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=True,
logits_config=logits_config,
name=name)
name=name,
**ResNet.CONFIGS[200])

0 comments on commit 300e6a4

Please sign in to comment.