[models] Vit: fix intermediate size scale and unify TF to PT (#1063)
felixdittrich92 authored Sep 19, 2022
1 parent 42195de commit 4e763da
Showing 5 changed files with 39 additions and 55 deletions.
6 changes: 4 additions & 2 deletions doctr/models/classification/vit/pytorch.py
@@ -62,6 +62,7 @@ class VisionTransformer(nn.Sequential):
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -74,6 +75,7 @@ def __init__(
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -82,7 +84,7 @@ def __init__(

_layers: List[nn.Module] = [
PatchEmbedding(input_shape, patch_size, d_model),
EncoderBlock(num_layers, num_heads, d_model, dropout, nn.GELU()),
EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
]
if include_top:
_layers.append(ClassifierHead(d_model, num_classes))
@@ -121,7 +123,7 @@ def _vit(


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.
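
For context on the intermediate-size fix: the encoder used to build each position-wise feedforward with a hidden width equal to d_model, whereas the ViT paper scales it by 4. The standalone PyTorch sketch below (the name feed_forward is illustrative, not doctr's module) shows the sizing the new ffd_ratio argument expresses; with the ViT-B defaults d_model=768 and ffd_ratio=4 the hidden layer becomes 3072.

import torch
from torch import nn

d_model, ffd_ratio = 768, 4  # ViT-B defaults used in the diff above

# The MLP hidden width is now d_model * ffd_ratio (= 3072) instead of
# silently reusing d_model (= 768), which is what the old call produced.
feed_forward = nn.Sequential(
    nn.Linear(d_model, d_model * ffd_ratio),  # 768 -> 3072
    nn.GELU(),
    nn.Linear(d_model * ffd_ratio, d_model),  # 3072 -> 768
)

x = torch.randn(2, 65, d_model)  # (batch, 1 cls token + 64 patches, d_model)
print(feed_forward(x).shape)  # torch.Size([2, 65, 768])
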
76 changes: 28 additions & 48 deletions doctr/models/classification/vit/tensorflow.py
@@ -31,6 +31,23 @@
}


class ClassifierHead(layers.Layer, NestedObject):
"""Classifier head for Vision Transformer
Args:
num_classes: number of output classes
"""

def __init__(self, num_classes: int) -> None:
super().__init__()

self.head = layers.Dense(num_classes, kernel_initializer="he_normal")

def call(self, x: tf.Tensor) -> tf.Tensor:
# (batch_size, num_classes) cls token
return self.head(x[:, 0])


class VisionTransformer(Sequential):
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
@@ -42,6 +59,7 @@ class VisionTransformer(Sequential):
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -54,60 +72,22 @@ def __init__(
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:

# Note: fix for onnx export
_vit = _VisionTransformer(
input_shape,
patch_size,
d_model,
num_layers,
num_heads,
dropout,
num_classes,
include_top,
)
super().__init__(_vit)
self.cfg = cfg


class _VisionTransformer(layers.Layer, NestedObject):
def __init__(
self,
input_shape: Tuple[int, int, int] = (32, 32, 3),
patch_size: Tuple[int, int] = (4, 4),
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:

super().__init__()
self.include_top = include_top

self.patch_embedding = PatchEmbedding(input_shape, patch_size, d_model)
self.encoder = EncoderBlock(num_layers, num_heads, d_model, dropout, activation_fct=GELU())

if self.include_top:
self.head = layers.Dense(num_classes, kernel_initializer="he_normal")

def __call__(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
_layers = [
PatchEmbedding(input_shape, patch_size, d_model),
EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, activation_fct=GELU()),
]
if include_top:
_layers.append(ClassifierHead(num_classes))

embeddings = self.patch_embedding(x, **kwargs)
encoded = self.encoder(embeddings, **kwargs)

if self.include_top:
# (batch_size, num_classes) cls token
return self.head(encoded[:, 0], **kwargs)

return encoded
super().__init__(_layers)
self.cfg = cfg


def _vit(
@@ -136,7 +116,7 @@ def _vit(


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.
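
The TensorFlow refactor above mirrors the PyTorch side: the _VisionTransformer wrapper layer goes away in favour of a flat list of layers handed to Sequential, with a small head that reads the class token. A self-contained Keras sketch of that pattern (ClsTokenHead and the Dense stand-in for the patch embedding/encoder are illustrative placeholders, not doctr classes):

import tensorflow as tf
from tensorflow.keras import Sequential, layers


class ClsTokenHead(layers.Layer):
    """Dense head applied to the first (class) token of the encoded sequence."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.head = layers.Dense(num_classes, kernel_initializer="he_normal")

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # (batch_size, seq_len, d_model) -> (batch_size, num_classes)
        return self.head(x[:, 0])


include_top = True
_layers = [layers.Dense(768)]  # stand-in for PatchEmbedding + EncoderBlock
if include_top:
    _layers.append(ClsTokenHead(num_classes=1000))

model = Sequential(_layers)
print(model(tf.random.normal((2, 65, 16))).shape)  # (2, 1000)
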
5 changes: 3 additions & 2 deletions doctr/models/modules/transformer/pytorch.py
@@ -108,6 +108,7 @@ def __init__(
num_layers: int,
num_heads: int,
d_model: int,
dff: int, # hidden dimension of the feedforward network
dropout: float,
activation_fct: Callable[[Any], Any] = nn.ReLU(),
) -> None:
@@ -124,7 +125,7 @@ def __init__(
[MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
)
self.position_feed_forward = nn.ModuleList(
[PositionwiseFeedForward(d_model, d_model, dropout, activation_fct) for _ in range(self.num_layers)]
[PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)]
)

def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -151,7 +152,7 @@ def __init__(
d_model: int,
vocab_size: int,
dropout: float = 0.2,
dff: int = 2048,
dff: int = 2048, # hidden dimension of the feedforward network
maximum_position_encoding: int = 50,
) -> None:

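
Making dff an explicit constructor argument is what lets the classifier pass d_model * ffd_ratio down to the encoder; before, every PositionwiseFeedForward was built with hidden width d_model. A rough sketch of what that changes in capacity per encoder layer (mlp and n_params are made-up helpers, not doctr code):

from torch import nn


def mlp(d_model: int, hidden: int) -> nn.Sequential:
    # Same shape as a position-wise feedforward: d_model -> hidden -> d_model
    return nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


d_model = 768
old = mlp(d_model, d_model)      # previous behaviour: hidden width == d_model
new = mlp(d_model, d_model * 4)  # with dff == d_model * ffd_ratio
print(n_params(old), n_params(new))  # 1181184 4722432
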
5 changes: 3 additions & 2 deletions doctr/models/modules/transformer/tensorflow.py
@@ -142,6 +142,7 @@ def __init__(
num_layers: int,
num_heads: int,
d_model: int,
dff: int, # hidden dimension of the feedforward network
dropout: float,
activation_fct: Callable[[Any], Any] = layers.ReLU(),
) -> None:
@@ -156,7 +157,7 @@

self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
self.position_feed_forward = [
PositionwiseFeedForward(d_model, d_model, dropout, activation_fct) for _ in range(self.num_layers)
PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
]

def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
@@ -186,7 +187,7 @@ def __init__(
d_model: int,
vocab_size: int,
dropout: float = 0.2,
dff: int = 2048,
dff: int = 2048, # hidden dimension of the feedforward network
maximum_position_encoding: int = 50,
) -> None:

2 changes: 1 addition & 1 deletion doctr/models/modules/vision_transformer/pytorch.py
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

# patchify image without convolution
# adopted from:
# adapted from:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
# NOTE: patchify with Conv2d works only with padding="valid" correctly on smaller images
# and has currently no ONNX support so we use this workaround
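
The comment about patchifying without convolution refers to splitting the image with a reshape/permute instead of a strided Conv2d, which keeps the op ONNX-exportable and behaves correctly on small inputs. A self-contained sketch in the spirit of the linked UvA notebook (patchify is an illustrative helper, not doctr's PatchEmbedding):

from typing import Tuple

import torch


def patchify(x: torch.Tensor, patch_size: Tuple[int, int]) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, C * pH * pW), no convolution involved
    B, C, H, W = x.shape
    pH, pW = patch_size
    assert H % pH == 0 and W % pW == 0, "image dims must be divisible by patch dims"
    x = x.reshape(B, C, H // pH, pH, W // pW, pW)
    x = x.permute(0, 2, 4, 1, 3, 5)        # (B, H', W', C, pH, pW)
    return x.flatten(1, 2).flatten(2, 4)   # flatten grid, then each patch


print(patchify(torch.randn(2, 3, 32, 32), (4, 4)).shape)  # torch.Size([2, 64, 48])
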
