Skip to content

Commit

Permalink
feat: add rotated ckpts for pytorch DBNet + fix line resolution for r…
Browse files Browse the repository at this point in the history
…otated pages (#743)

* refacto: rboxes

* feat: builder modifications

* fix: rotate_abs_boxes

* refacto: viz + metrics

* fix: flake8 + typing

* fix: debug 1

* fix: debug 2 tests

* fix: debug test 3

* fix: requested changes

* fix: test rotate

* fix: debug 4

* fix: cv2 tests metrics

* fix: revert changes

* fix: isort

* feat: add ckpt

* fix: thresh for rotated ckpt

* fix: utils + scripts

* fix: warning

* fix: warnings

* fix: line reconstruction

* fix: tests

* fix: viz

* fix: flake8

* fix: tf zoo tests

* fix: requested changes 1

* fix: tests

* fix: tests

* fix: scripts

* fix: dataset stack

* fix: typing

* fix: test zoo tf

* fix: builder

* fix: builder tests

* fix: tests merging conflicts

* fix: minor changes

* fix: crop extraction fn

* fix: typos merging conflicts

* fix: add empty line

* fix: requested changes

* fix: requested changes
  • Loading branch information
charlesmindee authored Dec 29, 2021
1 parent 76b2127 commit 1dc3374
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 66 deletions.
1 change: 1 addition & 0 deletions doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __getitem__(

img, target = self._read_sample(index)
h, w = self._get_img_shape(img)

if self.img_transforms is not None:
img = self.img_transforms(img)

Expand Down
4 changes: 1 addition & 3 deletions doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def __init__(
# Resolve the geometry using the smallest enclosing bounding box
if geometry is None:
# Check whether this is a rotated or straight box
box_resolution_fn = resolve_enclosing_rbbox if isinstance(
words[0].geometry, np.ndarray
) else resolve_enclosing_bbox
box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox
geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator, misc]

super().__init__(words=words)
Expand Down
8 changes: 4 additions & 4 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def rectify_loc_preds(
so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
"""
return np.stack(
[
page_loc_pred if orientation == 0 else np.roll(page_loc_pred, orientation, axis=0)
for orientation, page_loc_pred in zip(orientations, page_loc_preds)
],
[np.roll(
page_loc_pred,
orientation,
axis=0) for orientation, page_loc_pred in zip(orientations, page_loc_preds)],
axis=0
) if len(orientations) > 0 else None
78 changes: 30 additions & 48 deletions doctr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def _sort_boxes(boxes: np.ndarray) -> np.ndarray:
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
Returns:
indices of ordered boxes of shape (N,)
tuple: indices of ordered boxes of shape (N,), boxes
If straight boxes are passed to the function, boxes are unchanged
else: boxes returned are straight boxes fitted to the straightened rotated boxes
so that we fit the lines afterwards to the straightened page
"""
if boxes.ndim == 3:
boxes = rotate_boxes(
Expand All @@ -57,41 +60,34 @@ def _sort_boxes(boxes: np.ndarray) -> np.ndarray:
orig_shape=(1024, 1024),
min_angle=5.,
)
# Points are in this order: top left, top right, bot right, bot left
return (boxes[:, 0, 0] + 2 * boxes[:, 2, 1] / np.median(
np.linalg.norm(boxes[:, 2, :] - boxes[:, 1, :])
)).argsort()
return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort()
boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes

def _resolve_sub_lines(self, boxes: np.ndarray, words: List[int]) -> List[List[int]]:
def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]:
"""Split a line in sub_lines
Args:
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
words: list of indexes for the words of the line
boxes: bounding boxes of shape (N, 4)
word_idcs: list of indexes for the words of the line
Returns:
A list of (sub-)lines computed from the original line (words)
"""
lines = []
# Sort words horizontally
words = [words[j] for j in np.argsort(
[boxes[i, 0, 0] if len(boxes.shape) == 3 else boxes[i, 0] for i in words]
).tolist()]
word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()]

# Eventually split line horizontally
if len(words) < 2:
lines.append(words)
if len(word_idcs) < 2:
lines.append(word_idcs)
else:
sub_line = [words[0]]
for i in words[1:]:
sub_line = [word_idcs[0]]
for i in word_idcs[1:]:
horiz_break = True

prev_box = boxes[sub_line[-1]]
# Compute distance between boxes
if boxes.ndim == 3:
dist = boxes[i, 0, 0] - prev_box[0, 1]
else:
dist = boxes[i, 0] - prev_box[2]
dist = boxes[i, 0] - prev_box[2]
# If distance between boxes is lower than paragraph break, same sub-line
if dist < self.paragraph_break:
horiz_break = False
Expand All @@ -114,28 +110,23 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
Returns:
nested list of box indices
"""
# Compute median for boxes heights
y_med = np.median(boxes[:, 2, 1] - boxes[:, 1, 1] if boxes.ndim == 3 else boxes[:, 3] - boxes[:, 1])

# Sort boxes
idxs = self._sort_boxes(boxes)
# Sort boxes, and straighten the boxes if they are rotated
idxs, boxes = self._sort_boxes(boxes)

# Compute median for boxes heights
y_med = np.median(boxes[:, 3] - boxes[:, 1])

lines = []
words = [idxs[0]] # Assign the top-left word to the first line
# Define a mean y-center for the line
if boxes.ndim == 3:
y_center_sum = boxes[idxs[0]][([2, 1], [1, 1])].mean()
else:
y_center_sum = boxes[idxs[0]][[1, 3]].mean()
y_center_sum = boxes[idxs[0]][[1, 3]].mean()

for idx in idxs[1:]:
vert_break = True

# Compute y_dist
if boxes.ndim == 3:
y_dist = abs(boxes[idx][([2, 1], [1, 1])].mean() - y_center_sum / len(words))
else:
y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words))
y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words))
# If y-center of the box is close enough to mean y-center of the line, same line
if y_dist < y_med / 2:
vert_break = False
Expand All @@ -147,10 +138,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
y_center_sum = 0

words.append(idx)
if boxes.ndim == 3:
y_center_sum += boxes[idx][([2, 1], [1, 1])].mean()
else:
y_center_sum += boxes[idx][[1, 3]].mean()
y_center_sum += boxes[idx][[1, 3]].mean()

# Use the remaining words to form the last(s) line(s)
if len(words) > 0:
Expand Down Expand Up @@ -245,15 +233,15 @@ def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]])
# Decide whether we try to form lines
_boxes = boxes
if self.resolve_lines:
lines = self._resolve_lines(_boxes[:, :4])
lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4])
# Decide whether we try to form blocks
if self.resolve_blocks and len(lines) > 1:
_blocks = self._resolve_blocks(_boxes[:, :4], lines)
_blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines)
else:
_blocks = [lines]
else:
# Sort bounding boxes, one line for all boxes, one block for the line
lines = [self._sort_boxes(_boxes[:, :4])]
lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]]
_blocks = [lines]

blocks = [
Expand All @@ -262,7 +250,7 @@ def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]])
[
Word(
*word_preds[idx],
tuple(boxes[idx].tolist())
tuple([tuple(pt) for pt in boxes[idx].tolist()])
) if boxes.ndim == 3 else
Word(
*word_preds[idx],
Expand Down Expand Up @@ -297,7 +285,6 @@ def __call__(
Returns:
document object
"""

if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
raise ValueError("All arguments are expected to be lists of the same size")

Expand All @@ -306,14 +293,9 @@ def __call__(
if boxes[0].ndim == 3:
straight_boxes = []
# Iterate over pages
for page_boxes in boxes:
straight_boxes_page = []
for p_boxes in boxes:
# Iterate over boxes of the pages
for box in page_boxes:
xmin, xmax = np.min(box[:, 0]), np.max(box[:, 0])
ymin, ymax = np.min(box[:, 1]), np.max(box[:, 1])
straight_boxes_page.append([xmin, ymin, xmax, ymax])
straight_boxes.append(np.asarray(straight_boxes_page))
straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1))
boxes = straight_boxes

_pages = [
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def polygon_to_box(
expanded_points = np.asarray(_points) # expand polygon
if len(expanded_points) < 1:
return None
return cv2.boundingRect(expanded_points) if self.assume_straight_pages else cv2.boxPoints(
cv2.minAreaRect(expanded_points)
return cv2.boundingRect(expanded_points) if self.assume_straight_pages else np.roll(
cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0
)

def bitmap_to_boxes(
Expand Down
37 changes: 36 additions & 1 deletion doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...utils import load_pretrained_params
from .base import DBPostProcessor, _DBNet

__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large']
__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large', 'db_resnet50_rotation']


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand All @@ -39,6 +39,12 @@
'std': (0.264, 0.2749, 0.287),
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/db_mobilenet_v3_large-fd62154b.pt',
},
'db_resnet50_rotation': {
'input_shape': (3, 1024, 1024),
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/db_resnet50-1138863a.pt',
},
}


Expand Down Expand Up @@ -363,3 +369,32 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
'features',
**kwargs,
)


def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
    """Build the DBNet text detector registered under the ``'db_resnet50_rotation'`` architecture name.

    DBNet is described in `"Real-time Scene Text Detection with Differentiable Binarization"
    <https://arxiv.org/pdf/1911.08947.pdf>`_; this variant uses a ResNet-50 backbone and
    is trained with rotated documents.

    Example::
        >>> import torch
        >>> from doctr.models import db_resnet50_rotation
        >>> model = db_resnet50_rotation(pretrained=True)
        >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
        >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: extra keyword arguments, forwarded unchanged to ``_dbnet``

    Returns:
        text detection architecture
    """
    # Layer names passed to the shared DBNet builder alongside the backbone constructor
    stage_names = ['layer1', 'layer2', 'layer3', 'layer4']
    # NOTE(review): the fifth positional argument is None here, whereas the mobilenet
    # variant passes 'features' in the same slot; presumably it selects a backbone
    # submodule, confirm against the `_dbnet` signature.
    return _dbnet('db_resnet50_rotation', pretrained, resnet50, stage_names, None, **kwargs)
9 changes: 8 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

if is_tf_available():
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18']
ROT_ARCHS = []
elif is_torch_available():
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18']
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'db_resnet50_rotation']
ROT_ARCHS = ['db_resnet50_rotation']


def _predictor(
Expand All @@ -30,6 +32,11 @@ def _predictor(
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

if arch not in ROT_ARCHS and not assume_straight_pages:
raise AssertionError("You are trying to use a model trained on straight pages while not assuming"
" your pages are straight. If you have only straight documents, don't pass"
f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}")

# Detection
_model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
kwargs['mean'] = kwargs.get('mean', _model.cfg['mean'])
Expand Down
4 changes: 2 additions & 2 deletions doctr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def estimate_page_angle(polys: np.ndarray) -> float:
"""Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the
estimated angle ccw in degrees
"""
return np.mean(np.arctan2(
(polys[:, 1, 1] - polys[:, 0, 1]),
return np.median(np.arctan(
(polys[:, 0, 1] - polys[:, 1, 1]) / # Y axis from top to bottom!
(polys[:, 1, 0] - polys[:, 0, 0])
)) * 180 / np.pi
10 changes: 6 additions & 4 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from PIL import Image, ImageDraw
from unidecode import unidecode

from .common_types import BoundingBox
from .common_types import BoundingBox, Polygon4P
from .fonts import get_font

__all__ = ['visualize_page', 'synthesize_page', 'draw_boxes']
Expand Down Expand Up @@ -116,7 +116,7 @@ def polygon_patch(


def create_obj_patch(
geometry: Union[BoundingBox, np.ndarray],
geometry: Union[BoundingBox, Polygon4P, np.ndarray],
page_dimensions: Tuple[int, int],
**kwargs: Any,
) -> patches.Patch:
Expand All @@ -130,10 +130,12 @@ def create_obj_patch(
a matplotlib Patch
"""
if isinstance(geometry, tuple):
if len(geometry) == 2:
if len(geometry) == 2: # straight word BB (2 pts)
return rect_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type]
elif len(geometry) == 4:
elif len(geometry) == 4: # rotated word BB (4 pts)
return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) # type: ignore[arg-type]
elif isinstance(geometry, np.ndarray) and geometry.shape == (4, 2): # rotated line
return polygon_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type]
raise ValueError("invalid geometry format")


Expand Down
2 changes: 1 addition & 1 deletion tests/common/test_models_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_documentbuilder():
def test_sort_boxes(input_boxes, sorted_idxs):

doc_builder = builder.DocumentBuilder()
assert doc_builder._sort_boxes(np.asarray(input_boxes)).tolist() == sorted_idxs
assert doc_builder._sort_boxes(np.asarray(input_boxes))[0].tolist() == sorted_idxs


@pytest.mark.parametrize(
Expand Down

0 comments on commit 1dc3374

Please sign in to comment.