Skip to content

Commit

Permalink
Merge branch 'mindee:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Dec 3, 2021
2 parents eb7f59c + 8eb89e8 commit efaa2c0
Show file tree
Hide file tree
Showing 27 changed files with 241 additions and 104 deletions.
1 change: 1 addition & 0 deletions doctr/datasets/cord.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

stem = Path(img_path).stem
_targets = []
with open(os.path.join(self.root, 'json', f"{stem}.json"), 'rb') as f:
Expand Down
4 changes: 4 additions & 0 deletions doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(

self.data: List[Tuple[str, np.ndarray]] = []
for img_name, label in labels.items():
# File existence check
if not os.path.exists(os.path.join(self.root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

polygons = np.asarray(label['polygons'])
if rotated_bbox:
# Switch to rotated rects
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/doc_artefacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

boxes = np.asarray([obj['geometry'] for obj in label], dtype=np_dtype)
classes = np.asarray([self.CLASSES.index(obj['label']) for obj in label], dtype=np.long)
if rotated_bbox:
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/funsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

stem = Path(img_path).stem
with open(os.path.join(self.root, subfolder, 'annotations', f"{stem}.json"), 'rb') as f:
data = json.load(f)
Expand Down
14 changes: 6 additions & 8 deletions doctr/datasets/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,12 @@ def __init__(
self.data: List[Tuple[str, str]] = []
with open(labels_path) as f:
labels = json.load(f)
for img_path in os.listdir(self.root):
# File existence check
if not os.path.exists(os.path.join(self.root, img_path)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_path)}")
label = labels.get(img_path)
if not isinstance(label, str):
raise KeyError("Image is not in referenced in label file")
self.data.append((img_path, label))

for img_name, label in labels.items():
if not os.path.exists(os.path.join(self.root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

self.data.append((img_name, label))

def merge_dataset(self, ds: AbstractDataset) -> None:
# Update data with new root for self
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/sroie.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

stem = Path(img_path).stem
_targets = []
with open(os.path.join(self.root, 'annotations', f"{stem}.txt"), encoding='latin') as f:
Expand Down
6 changes: 6 additions & 0 deletions doctr/models/detection/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available

if is_tf_available():
from .tensorflow import *
else:
from .pytorch import * # type: ignore[misc]
37 changes: 37 additions & 0 deletions doctr/models/detection/_utils/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from torch import Tensor
from torch.nn.functional import max_pool2d

__all__ = ['erode', 'dilate']


def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Erode a binary map using a sliding-window minimum.

    The minimum is obtained through max-pooling of the complement: the input
    is flipped (``1 - x``), max-pooled with a stride of 1, and flipped back.

    Args:
        x: tensor of shape (N, C, H, W) whose values are expected to be 0 or 1
        kernel_size: the size of the kernel to use for erosion

    Returns:
        the eroded tensor
    """
    # NOTE(review): the arithmetic complement below suggests a numeric 0/1 mask
    # (e.g. float) rather than a torch.bool dtype -- confirm against callers.
    complement = 1 - x

    # With a stride of 1, this padding keeps H and W unchanged for odd kernel sizes
    border = (kernel_size - 1) // 2
    pooled = max_pool2d(complement, kernel_size, stride=1, padding=border)

    return 1 - pooled


def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Dilate a binary map using a sliding-window maximum.

    Args:
        x: tensor of shape (N, C, H, W) whose values are expected to be 0 or 1
        kernel_size: the size of the kernel to use for dilation

    Returns:
        the dilated tensor
    """
    # With a stride of 1, this padding keeps H and W unchanged for odd kernel sizes
    border = (kernel_size - 1) // 2
    dilated = max_pool2d(x, kernel_size, stride=1, padding=border)

    return dilated
34 changes: 34 additions & 0 deletions doctr/models/detection/_utils/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import tensorflow as tf

__all__ = ['erode', 'dilate']


def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Erode a binary map using a sliding-window minimum.

    The minimum is obtained through max-pooling of the complement: the input
    is flipped (``1 - x``), max-pooled with a stride of 1 and "SAME" padding,
    and flipped back.

    Args:
        x: tensor of shape (N, H, W, C) whose values are expected to be 0 or 1
        kernel_size: the size of the kernel to use for erosion

    Returns:
        the eroded tensor
    """
    # NOTE(review): the arithmetic complement below suggests a numeric 0/1 mask
    # rather than a tf.bool dtype -- confirm against callers.
    complement = 1 - x
    pooled = tf.nn.max_pool2d(complement, kernel_size, strides=1, padding="SAME")

    return 1 - pooled


def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
    """Dilate a binary map using a sliding-window maximum.

    Args:
        x: tensor of shape (N, H, W, C) whose values are expected to be 0 or 1
        kernel_size: the size of the kernel to use for dilation

    Returns:
        the dilated tensor
    """
    # Stride of 1 with "SAME" padding: the output keeps the input's H and W
    dilated = tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")

    return dilated
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def bitmap_to_boxes(
else:
score = self.box_score(pred, contour, assume_straight_pages=False)

if self.box_thresh > score: # remove polygons with a weak objectness
if score < self.box_thresh: # remove polygons with a weak objectness
continue

if self.assume_straight_pages:
Expand Down Expand Up @@ -253,7 +253,7 @@ def draw_thresh_map(

return polygon, canvas, mask

def compute_target(
def build_target(
self,
target: List[np.ndarray],
output_shape: Tuple[int, int, int],
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False),
nn.BatchNorm2d(head_chans // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(head_chans // 4, 1, 2, stride=2),
nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
)
self.thresh_head = nn.Sequential(
conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False),
Expand Down Expand Up @@ -213,7 +213,7 @@ def compute_loss(
prob_map = torch.sigmoid(out_map.squeeze(1))
thresh_map = torch.sigmoid(thresh_map.squeeze(1))

targets = self.compute_target(target, prob_map.shape) # type: ignore[arg-type]
targets = self.build_target(target, prob_map.shape) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
Args:
feature extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
num_classes: number of output channels in the segmentation map
assume_straight_pages: if True, fit straight bounding boxes only
cfg: the configuration dict of the model
"""
Expand All @@ -123,6 +124,7 @@ def __init__(
self,
feature_extractor: IntermediateLayerGetter,
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
num_classes: int = 1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -144,7 +146,7 @@ def __init__(
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.Conv2DTranspose(1, 2, strides=2, kernel_initializer='he_normal'),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'),
]
)
self.threshold_head = keras.Sequential(
Expand All @@ -153,7 +155,7 @@ def __init__(
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.Conv2DTranspose(1, 2, strides=2, kernel_initializer='he_normal'),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'),
]
)

Expand All @@ -180,7 +182,7 @@ def compute_loss(
prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1]))
thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1]))

seg_target, seg_mask, thresh_target, thresh_mask = self.compute_target(target, out_map.shape[:3])
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[:3])
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
Expand Down
42 changes: 27 additions & 15 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cv2
import numpy as np

from doctr.file_utils import is_tf_available
from doctr.models.core import BaseModel
from doctr.utils.geometry import fit_rbbox, rbbox_to_polygon

Expand Down Expand Up @@ -71,7 +72,7 @@ def bitmap_to_boxes(
else:
score = self.box_score(pred, contour, assume_straight_pages=False)

if self.box_thresh > score: # remove polygons with a weak objectness
if score < self.box_thresh: # remove polygons with a weak objectness
continue

if self.assume_straight_pages:
Expand Down Expand Up @@ -105,24 +106,27 @@ class _LinkNet(BaseModel):
min_size_box: int = 3
assume_straight_pages: bool = True

def compute_target(
def build_target(
self,
target: List[np.ndarray],
output_shape: Tuple[int, int, int],
output_shape: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

if any(t.dtype != np.float32 for t in target):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target):
raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

h, w = output_shape
target_shape = (len(target), h, w, 1)

if self.assume_straight_pages:
seg_target = np.zeros(output_shape, dtype=bool)
edge_mask = np.zeros(output_shape, dtype=bool)
seg_target = np.zeros(target_shape, dtype=bool)
edge_mask = np.zeros(target_shape, dtype=bool)
else:
seg_target = np.zeros(output_shape, dtype=np.uint8)
seg_target = np.zeros(target_shape, dtype=np.uint8)

seg_mask = np.ones(output_shape, dtype=bool)
seg_mask = np.ones(target_shape, dtype=bool)

for idx, _target in enumerate(target):
# Draw each polygon on gt
Expand All @@ -132,8 +136,8 @@ def compute_target(

# Absolute bounding boxes
abs_boxes = _target.copy()
abs_boxes[:, [0, 2]] *= output_shape[-1]
abs_boxes[:, [1, 3]] *= output_shape[-2]
abs_boxes[:, [0, 2]] *= w
abs_boxes[:, [1, 3]] *= h
abs_boxes = abs_boxes.round().astype(np.int32)

if abs_boxes.shape[1] == 5:
Expand All @@ -155,11 +159,19 @@ def compute_target(
cv2.fillPoly(seg_target[idx], [poly.astype(np.int32)], 1)
else:
seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True
# fill the 2 vertical edges
edge_mask[idx, max(0, box[1] - 1): min(box[1] + 1, box[3]), box[0]: box[2] + 1] = True
edge_mask[idx, max(box[1] + 1, box[3]): min(output_shape[1], box[3] + 2), box[0]: box[2] + 1] = True
# fill the 2 horizontal edges
edge_mask[idx, box[1]: box[3] + 1, max(0, box[0] - 1): min(box[0] + 1, box[2])] = True
edge_mask[idx, box[1]: box[3] + 1, max(box[0] + 1, box[2]): min(output_shape[2], box[2] + 2)] = True
# top edge
edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True
# bot edge
edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True
# left edge
edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True
# right edge
edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True

# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
seg_target = seg_target.transpose(0, 3, 1, 2)
seg_mask = seg_mask.transpose(0, 3, 1, 2)
edge_mask = edge_mask.transpose(0, 3, 1, 2)

return seg_target, seg_mask, edge_mask
44 changes: 12 additions & 32 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def forward(
target: Optional[List[np.ndarray]] = None,
return_model_output: bool = False,
return_boxes: bool = False,
focal_loss: bool = True,
**kwargs: Any,
) -> Dict[str, Any]:

Expand All @@ -185,7 +184,7 @@ def forward(
out["preds"] = self.postprocessor(prob_map.squeeze(1).detach().cpu().numpy())

if target is not None:
loss = self.compute_loss(logits, target, focal_loss)
loss = self.compute_loss(logits, target)
out['loss'] = loss

return out
Expand All @@ -194,53 +193,34 @@ def compute_loss(
self,
out_map: torch.Tensor,
target: List[np.ndarray],
focal_loss: bool = False,
alpha: float = .5,
gamma: float = 2.,
edge_factor: float = 2.,
) -> torch.Tensor:
"""Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
<https://github.com/tensorflow/addons/>`_.
Args:
out_map: output feature map of the model of shape N x H x W x 1
out_map: output feature map of the model of shape (N, 1, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
focal_loss: if True, use focal loss instead of BCE
edge_factor: boost factor for box edges (in case of BCE)
alpha: balancing factor in the focal loss formula
gamma: modulating factor in the focal loss formula
Returns:
A loss tensor
"""
targets = self.compute_target(target, out_map.shape) # type: ignore[arg-type]
seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(targets[0]).to(dtype=out_map.dtype), torch.from_numpy(targets[1])
seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask)
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
edge_mask = torch.from_numpy(targets[2]).to(out_map.device)
if edge_factor > 0:
edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device)

# Get the cross_entropy for each entry
bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')[seg_mask]
loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')

if focal_loss:
if gamma and gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")

# Convert logits to prob, compute gamma factor
pred_prob = torch.sigmoid(out_map)[seg_mask]
p_t = (seg_target[seg_mask] * pred_prob) + ((1 - seg_target[seg_mask]) * (1 - pred_prob))

# Compute alpha factor
alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)

# compute the final loss
loss = (alpha_factor * (1. - p_t) ** gamma * bce).mean()

else:
# Compute BCE loss with highlighted edges
loss = ((1 + (edge_factor - 1) * edge_mask) * bce).mean()

return loss
# Compute BCE loss with highlighted edges
if edge_factor > 0:
loss = ((1 + (edge_factor - 1) * edge_mask) * loss)
# Only consider contributions overlapping the mask
return loss[seg_mask].mean()


def _linknet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **kwargs: Any) -> LinkNet:
Expand Down
Loading

0 comments on commit efaa2c0

Please sign in to comment.