Skip to content

Commit

Permalink
[detection] move padding removal directly to detection (#1627)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jun 5, 2024
1 parent 8ee03e6 commit 5eda559
Show file tree
Hide file tree
Showing 19 changed files with 170 additions and 76 deletions.
8 changes: 4 additions & 4 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ should yield
"name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg",
"geometries": [
[
0.724609375,
0.8176307908857315,
0.1787109375,
0.7900390625,
0.9101580212741838,
0.2080078125
],
[
0.6748046875,
0.7471996155154171,
0.1796875,
0.7314453125,
0.8272978149561669,
0.20703125
]
]
Expand Down
20 changes: 10 additions & 10 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,31 @@ def mock_detection_response():
"box": {
"name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg",
"geometries": [
[0.724609375, 0.1787109375, 0.7900390625, 0.2080078125],
[0.6748046875, 0.1796875, 0.7314453125, 0.20703125],
[0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125],
[0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125],
],
},
"poly": {
"name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg",
"geometries": [
[
0.7873152494430542,
0.9063061475753784,
0.17740710079669952,
0.7884310483932495,
0.9078840017318726,
0.20474515855312347,
0.7244035005569458,
0.8173396587371826,
0.20735852420330048,
0.7232877016067505,
0.8157618045806885,
0.18002046644687653,
],
[
0.7286394834518433,
0.8233299851417542,
0.17740298807621002,
0.7298480272293091,
0.8250390291213989,
0.2027825564146042,
0.6746810674667358,
0.7470247745513916,
0.20540954172611237,
0.67347252368927,
0.7453157305717468,
0.1800299733877182,
],
],
Expand Down
2 changes: 1 addition & 1 deletion doctr/datasets/imgur5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
if ann["word"] != "."
]
# (x, y) coordinates of top left, top right, bottom right, bottom left corners
box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type]
box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]

if not use_polygons:
# xmin, ymin, xmax, ymax
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = cv2.medianBlur(gray_img, 5)
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # type: ignore[assignment]
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

# try to merge words in lines
(h, w) = img.shape[:2]
k_x = max(1, (floor(w / 100)))
k_y = max(1, (floor(h / 100)))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
thresh = cv2.dilate(thresh, kernel, iterations=1) # type: ignore[assignment]
thresh = cv2.dilate(thresh, kernel, iterations=1)

# extract contours
contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
Expand All @@ -68,7 +68,7 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li

angles = []
for contour in contours[:n_ct]:
_, (w, h), angle = cv2.minAreaRect(contour)
_, (w, h), angle = cv2.minAreaRect(contour) # type: ignore[assignment]
if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
angles.append(angle)
elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree
Expand Down
1 change: 1 addition & 0 deletions doctr/models/detection/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from doctr.file_utils import is_tf_available
from .base import *

if is_tf_available():
from .tensorflow import *
Expand Down
66 changes: 66 additions & 0 deletions doctr/models/detection/_utils/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict, List

import numpy as np

__all__ = ["_remove_padding"]


def _remove_padding(
pages: List[np.ndarray],
loc_preds: List[Dict[str, np.ndarray]],
preserve_aspect_ratio: bool,
symmetric_pad: bool,
assume_straight_pages: bool,
) -> List[Dict[str, np.ndarray]]:
"""Remove padding from the localization predictions
Args:
----
pages: list of pages
loc_preds: list of localization predictions
preserve_aspect_ratio: whether the aspect ratio was preserved during padding
symmetric_pad: whether the padding was symmetric
assume_straight_pages: whether the pages are assumed to be straight
Returns:
-------
list of unpaded localization predictions
"""
if preserve_aspect_ratio:
# Rectify loc_preds to remove padding
rectified_preds = []
for page, dict_loc_preds in zip(pages, loc_preds):
for k, loc_pred in dict_loc_preds.items():
h, w = page.shape[0], page.shape[1]
if h > w:
# y unchanged, dilate x coord
if symmetric_pad:
if assume_straight_pages:
loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
else:
loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
else:
if assume_straight_pages:
loc_pred[:, [0, 2]] *= h / w
else:
loc_pred[:, :, 0] *= h / w
elif w > h:
# x unchanged, dilate y coord
if symmetric_pad:
if assume_straight_pages:
loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
else:
loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
else:
if assume_straight_pages:
loc_pred[:, [1, 3]] *= w / h
else:
loc_pred[:, :, 1] *= w / h
rectified_preds.append({k: np.clip(loc_pred, 0, 1)})
return rectified_preds
return loc_preds
2 changes: 1 addition & 1 deletion doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def bitmap_to_boxes(
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
# Check whether smallest enclosing bounding box is not too small
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): # type: ignore[index]
continue
# Compute objectness
if self.assume_straight_pages:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def bitmap_to_boxes(
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
# Check whether smallest enclosing bounding box is not too small
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
continue
# Compute objectness
if self.assume_straight_pages:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def bitmap_to_boxes(
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
# Check whether smallest enclosing bounding box is not too small
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
continue
# Compute objectness
if self.assume_straight_pages:
Expand Down
16 changes: 15 additions & 1 deletion doctr/models/detection/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from torch import nn

from doctr.models.detection._utils import _remove_padding
from doctr.models.preprocessor import PreProcessor
from doctr.models.utils import set_device_and_dtype

Expand Down Expand Up @@ -40,6 +41,11 @@ def forward(
return_maps: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
# Extract parameters from the preprocessor
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
symmetric_pad = self.pre_processor.resize.symmetric_pad
assume_straight_pages = self.model.assume_straight_pages

# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
Expand All @@ -52,7 +58,15 @@ def forward(
predicted_batches = [
self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
]
preds = [pred for batch in predicted_batches for pred in batch["preds"]]
# Remove padding from loc predictions
preds = _remove_padding(
pages, # type: ignore[arg-type]
[pred for batch in predicted_batches for pred in batch["preds"]],
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
assume_straight_pages=assume_straight_pages,
)

if return_maps:
seg_maps = [
pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
Expand Down
16 changes: 15 additions & 1 deletion doctr/models/detection/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tensorflow as tf
from tensorflow import keras

from doctr.models.detection._utils import _remove_padding
from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject

Expand Down Expand Up @@ -40,6 +41,11 @@ def __call__(
return_maps: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
# Extract parameters from the preprocessor
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
symmetric_pad = self.pre_processor.resize.symmetric_pad
assume_straight_pages = self.model.assume_straight_pages

# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
Expand All @@ -50,7 +56,15 @@ def __call__(
for batch in processed_batches
]

preds = [pred for batch in predicted_batches for pred in batch["preds"]]
# Remove padding from loc predictions
preds = _remove_padding(
pages,
[pred for batch in predicted_batches for pred in batch["preds"]],
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
assume_straight_pages=assume_straight_pages,
)

if return_maps:
seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
return preds, seg_maps
Expand Down
3 changes: 0 additions & 3 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def forward(
# Check whether crop mode should be switched to channels first
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} # type: ignore[arg-type]

# Apply hooks to loc_preds if any
for hook in self.hooks:
dict_loc_preds = hook(dict_loc_preds)
Expand Down
2 changes: 0 additions & 2 deletions doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def __call__(
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]

dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

# Apply hooks to loc_preds if any
for hook in self.hooks:
Expand Down
38 changes: 0 additions & 38 deletions doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,44 +104,6 @@ def _rectify_crops(
]
return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]

def _remove_padding(
self,
pages: List[np.ndarray],
loc_preds: List[np.ndarray],
) -> List[np.ndarray]:
if self.preserve_aspect_ratio:
# Rectify loc_preds to remove padding
rectified_preds = []
for page, loc_pred in zip(pages, loc_preds):
h, w = page.shape[0], page.shape[1]
if h > w:
# y unchanged, dilate x coord
if self.symmetric_pad:
if self.assume_straight_pages:
loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
else:
loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
else:
if self.assume_straight_pages:
loc_pred[:, [0, 2]] *= h / w
else:
loc_pred[:, :, 0] *= h / w
elif w > h:
# x unchanged, dilate y coord
if self.symmetric_pad:
if self.assume_straight_pages:
loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
else:
loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
else:
if self.assume_straight_pages:
loc_pred[:, [1, 3]] *= w / h
else:
loc_pred[:, :, 1] *= w / h
rectified_preds.append(loc_pred)
return rectified_preds
return loc_preds

@staticmethod
def _process_predictions(
loc_preds: List[np.ndarray],
Expand Down
3 changes: 0 additions & 3 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ def forward(
# Check whether crop mode should be switched to channels first
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds) # type: ignore[arg-type]

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)
Expand Down
3 changes: 0 additions & 3 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ def __call__(
), "Detection Model in ocr_predictor should output only one class"
loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]

# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds)

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)
Expand Down
2 changes: 1 addition & 1 deletion doctr/transforms/functional/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@ def create_shadow_mask(
mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]

return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32) # type: ignore[operator]
return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
6 changes: 3 additions & 3 deletions doctr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024
# Convert to absolute for minAreaRect
cloud *= intermed_size
rect = cv2.minAreaRect(cloud.astype(np.int32))
return cv2.boxPoints(rect) / intermed_size # type: ignore[operator]
return cv2.boxPoints(rect) / intermed_size # type: ignore[return-value]


def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray:
Expand Down Expand Up @@ -320,7 +320,7 @@ def rotate_image(
# Pad height
else:
h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0
rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0)))
rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) # type: ignore[assignment]
if preserve_origin_shape:
# rescale
rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR)
Expand Down Expand Up @@ -453,4 +453,4 @@ def extract_rcrops(
)
for idx in range(_boxes.shape[0])
]
return crops
return crops # type: ignore[return-value]
Loading

0 comments on commit 5eda559

Please sign in to comment.