diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index d016c2068e..b55089c877 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -61,6 +61,7 @@ def __getitem__( img, target = self._read_sample(index) h, w = self._get_img_shape(img) + if self.img_transforms is not None: img = self.img_transforms(img) diff --git a/doctr/io/elements.py b/doctr/io/elements.py index de3e9cd1bc..61bfdbaf68 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -140,9 +140,7 @@ def __init__( # Resolve the geometry using the smallest enclosing bounding box if geometry is None: # Check whether this is a rotated or straight box - box_resolution_fn = resolve_enclosing_rbbox if isinstance( - words[0].geometry, np.ndarray - ) else resolve_enclosing_bbox + box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator, misc] super().__init__(words=words) diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py index 6433c1090b..b521ada1e6 100644 --- a/doctr/models/_utils.py +++ b/doctr/models/_utils.py @@ -214,9 +214,9 @@ def rectify_loc_preds( so that the points are in this order: top L, top R, bot R, bot L if the crop is readable """ return np.stack( - [ - page_loc_pred if orientation == 0 else np.roll(page_loc_pred, orientation, axis=0) - for orientation, page_loc_pred in zip(orientations, page_loc_preds) - ], + [np.roll( + page_loc_pred, + orientation, + axis=0) for orientation, page_loc_pred in zip(orientations, page_loc_preds)], axis=0 ) if len(orientations) > 0 else None diff --git a/doctr/models/builder.py b/doctr/models/builder.py index c20ae0a518..3b539b0afd 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -48,7 +48,10 @@ def _sort_boxes(boxes: np.ndarray) -> np.ndarray: boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) Returns: - indices of ordered boxes of shape (N,) + tuple: indices of ordered boxes of shape (N,), boxes + If straight boxes are passed tpo the function, boxes are unchanged + else: boxes returned are straight boxes fitted to the straightened rotated boxes + so that we fit the lines afterwards to the straigthened page """ if boxes.ndim == 3: boxes = rotate_boxes( @@ -57,41 +60,34 @@ def _sort_boxes(boxes: np.ndarray) -> np.ndarray: orig_shape=(1024, 1024), min_angle=5., ) - # Points are in this order: top left, top right, bot right, bot left - return (boxes[:, 0, 0] + 2 * boxes[:, 2, 1] / np.median( - np.linalg.norm(boxes[:, 2, :] - boxes[:, 1, :]) - )).argsort() - return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort() + boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1) + return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes - def _resolve_sub_lines(self, boxes: np.ndarray, words: List[int]) -> List[List[int]]: + def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]: """Split a line in sub_lines Args: - boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox - words: list of indexes for the words of the line + boxes: bounding boxes of shape (N, 4) + word_idcs: list of indexes for the words of the line Returns: A list of (sub-)lines computed from the original line (words) """ lines = [] # Sort words horizontally - words = [words[j] for j in np.argsort( - [boxes[i, 0, 0] if len(boxes.shape) == 3 else boxes[i, 0] for i in words] - ).tolist()] + word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()] + # Eventually split line horizontally - if len(words) < 2: - lines.append(words) + if len(word_idcs) < 2: + lines.append(word_idcs) else: - sub_line = [words[0]] - for i in words[1:]: + sub_line = [word_idcs[0]] + for i in word_idcs[1:]: horiz_break = True prev_box = boxes[sub_line[-1]] # Compute distance between boxes - if boxes.ndim == 3: - dist = boxes[i, 0, 0] - prev_box[0, 1] - else: - dist = boxes[i, 0] - prev_box[2] + dist = boxes[i, 0] - prev_box[2] # If distance between boxes is lower than paragraph break, same sub-line if dist < self.paragraph_break: horiz_break = False @@ -114,28 +110,23 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: Returns: nested list of box indices """ - # Compute median for boxes heights - y_med = np.median(boxes[:, 2, 1] - boxes[:, 1, 1] if boxes.ndim == 3 else boxes[:, 3] - boxes[:, 1]) - # Sort boxes - idxs = self._sort_boxes(boxes) + # Sort boxes, and straighten the boxes if they are rotated + idxs, boxes = self._sort_boxes(boxes) + + # Compute median for boxes heights + y_med = np.median(boxes[:, 3] - boxes[:, 1]) lines = [] words = [idxs[0]] # Assign the top-left word to the first line # Define a mean y-center for the line - if boxes.ndim == 3: - y_center_sum = boxes[idxs[0]][([2, 1], [1, 1])].mean() - else: - y_center_sum = boxes[idxs[0]][[1, 3]].mean() + y_center_sum = boxes[idxs[0]][[1, 3]].mean() for idx in idxs[1:]: vert_break = True # Compute y_dist - if boxes.ndim == 3: - y_dist = abs(boxes[idx][([2, 1], [1, 1])].mean() - y_center_sum / len(words)) - else: - y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words)) + y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words)) # If y-center of the box is close enough to mean y-center of the line, same line if y_dist < y_med / 2: vert_break = False @@ -147,10 +138,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: y_center_sum = 0 words.append(idx) - if boxes.ndim == 3: - y_center_sum += boxes[idx][([2, 1], [1, 1])].mean() - else: - y_center_sum += boxes[idx][[1, 3]].mean() + y_center_sum += boxes[idx][[1, 3]].mean() # Use the remaining words to form the last(s) line(s) if len(words) > 0: @@ -245,15 +233,15 @@ def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]]) # Decide whether we try to form lines _boxes = boxes if self.resolve_lines: - lines = self._resolve_lines(_boxes[:, :4]) + lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) # Decide whether we try to form blocks if self.resolve_blocks and len(lines) > 1: - _blocks = self._resolve_blocks(_boxes[:, :4], lines) + _blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines) else: _blocks = [lines] else: # Sort bounding boxes, one line for all boxes, one block for the line - lines = [self._sort_boxes(_boxes[:, :4])] + lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]] _blocks = [lines] blocks = [ @@ -262,7 +250,7 @@ def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]]) [ Word( *word_preds[idx], - tuple(boxes[idx].tolist()) + tuple([tuple(pt) for pt in boxes[idx].tolist()]) ) if boxes.ndim == 3 else Word( *word_preds[idx], @@ -297,7 +285,6 @@ def __call__( Returns: document object """ - if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes): raise ValueError("All arguments are expected to be lists of the same size") @@ -306,14 +293,9 @@ def __call__( if boxes[0].ndim == 3: straight_boxes = [] # Iterate over pages - for page_boxes in boxes: - straight_boxes_page = [] + for p_boxes in boxes: # Iterate over boxes of the pages - for box in page_boxes: - xmin, xmax = np.min(box[:, 0]), np.max(box[:, 0]) - ymin, ymax = np.min(box[:, 1]), np.max(box[:, 1]) - straight_boxes_page.append([xmin, ymin, xmax, ymax]) - straight_boxes.append(np.asarray(straight_boxes_page)) + straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1)) boxes = straight_boxes _pages = [ diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 5ae007f714..9ebc223287 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -83,8 +83,8 @@ def polygon_to_box( expanded_points = np.asarray(_points) # expand polygon if len(expanded_points) < 1: return None - return cv2.boundingRect(expanded_points) if self.assume_straight_pages else cv2.boxPoints( - cv2.minAreaRect(expanded_points) + return cv2.boundingRect(expanded_points) if self.assume_straight_pages else np.roll( + cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0 ) def bitmap_to_boxes( diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 19dedf32ed..1d1fc848ce 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -17,7 +17,7 @@ from ...utils import load_pretrained_params from .base import DBPostProcessor, _DBNet -__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large'] +__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large', 'db_resnet50_rotation'] default_cfgs: Dict[str, Dict[str, Any]] = { @@ -39,6 +39,12 @@ 'std': (0.264, 0.2749, 0.287), 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/db_mobilenet_v3_large-fd62154b.pt', }, + 'db_resnet50_rotation': { + 'input_shape': (3, 1024, 1024), + 'mean': (0.798, 0.785, 0.772), + 'std': (0.264, 0.2749, 0.287), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/db_resnet50-1138863a.pt', + }, } @@ -363,3 +369,32 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: 'features', **kwargs, ) + + +def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. + This model is trained with rotated documents + + Example:: + >>> import torch + >>> from doctr.models import db_resnet50_rotation + >>> model = db_resnet50_rotation(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _dbnet( + 'db_resnet50_rotation', + pretrained, + resnet50, + ['layer1', 'layer2', 'layer3', 'layer4'], + None, + **kwargs, + ) diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index ae84e2d1c3..af20d6671a 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -16,8 +16,10 @@ if is_tf_available(): ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18'] + ROT_ARCHS = [] elif is_torch_available(): - ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18'] + ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'db_resnet50_rotation'] + ROT_ARCHS = ['db_resnet50_rotation'] def _predictor( @@ -30,6 +32,11 @@ def _predictor( if arch not in ARCHS: raise ValueError(f"unknown architecture '{arch}'") + if arch not in ROT_ARCHS and not assume_straight_pages: + raise AssertionError("You are trying to use a model trained on straight pages while not assuming" + " your pages are straight. If you have only straight documents, don't pass" + f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}") + # Detection _model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages) kwargs['mean'] = kwargs.get('mean', _model.cfg['mean']) diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index c5869877ff..e38c39d7a9 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -239,7 +239,7 @@ def estimate_page_angle(polys: np.ndarray) -> float: """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the estimated angle ccw in degrees """ - return np.mean(np.arctan2( - (polys[:, 1, 1] - polys[:, 0, 1]), + return np.median(np.arctan( + (polys[:, 0, 1] - polys[:, 1, 1]) / # Y axis from top to bottom! (polys[:, 1, 0] - polys[:, 0, 0]) )) * 180 / np.pi diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index edc7e5c6ff..b67feaa6cc 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -15,7 +15,7 @@ from PIL import Image, ImageDraw from unidecode import unidecode -from .common_types import BoundingBox +from .common_types import BoundingBox, Polygon4P from .fonts import get_font __all__ = ['visualize_page', 'synthesize_page', 'draw_boxes'] @@ -116,7 +116,7 @@ def polygon_patch( def create_obj_patch( - geometry: Union[BoundingBox, np.ndarray], + geometry: Union[BoundingBox, Polygon4P, np.ndarray], page_dimensions: Tuple[int, int], **kwargs: Any, ) -> patches.Patch: @@ -130,10 +130,12 @@ def create_obj_patch( a matplotlib Patch """ if isinstance(geometry, tuple): - if len(geometry) == 2: + if len(geometry) == 2: # straight word BB (2 pts) return rect_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type] - elif len(geometry) == 4: + elif len(geometry) == 4: # rotated word BB (4 pts) return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) # type: ignore[arg-type] + elif isinstance(geometry, np.ndarray) and geometry.shape == (4, 2): # rotated line + return polygon_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type] raise ValueError("invalid geometry format") diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py index 9ebb67f788..ec468c310c 100644 --- a/tests/common/test_models_builder.py +++ b/tests/common/test_models_builder.py @@ -67,7 +67,7 @@ def test_documentbuilder(): def test_sort_boxes(input_boxes, sorted_idxs): doc_builder = builder.DocumentBuilder() - assert doc_builder._sort_boxes(np.asarray(input_boxes)).tolist() == sorted_idxs + assert doc_builder._sort_boxes(np.asarray(input_boxes))[0].tolist() == sorted_idxs @pytest.mark.parametrize(