diff --git a/README.md b/README.md index 6e89cf0b5a..5a2d65d8ca 100644 --- a/README.md +++ b/README.md @@ -56,15 +56,26 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images() result = model(doc) ``` -To make sense of your model's predictions, you can visualize them as follows: +To make sense of your model's predictions, you can visualize them interactively as follows: ```python result.show(doc) ``` -![DocTR example](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif) +![Visualization sample](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif) -The ocr_predictor returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`). +Or even rebuild the original document from its predictions: + +```python +import matplotlib.pyplot as plt + +plt.imshow(result.synthesize()); plt.axis('off'); plt.show() +``` + +![Synthesis sample](https://github.com/mindee/doctr/releases/download/v0.3.1/synthesized_sample.png) + + +The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`). To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/io.html#document-structure): You can also export them as a nested dict, more appropriate for JSON format: diff --git a/docs/source/utils.rst b/docs/source/utils.rst index d74b54130d..c15f3a786e 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -14,6 +14,8 @@ Easy-to-use functions to make sense of your model's predictions. .. autofunction:: visualize_page +.. autofunction:: synthesize_page + .. 
def synthesize(self, **kwargs) -> np.ndarray:
    """Rebuild an image of the page from its predictions

    Args:
        **kwargs: keyword arguments passed through to `synthesize_page`

    Returns:
        synthesized page
    """

    # Export first, then hand the exported page over to the rendering helper
    exported_page = self.export()
    return synthesize_page(exported_page, **kwargs)
def synthesize(self, **kwargs) -> List[np.ndarray]:
    """Synthesize all pages from their predictions

    Args:
        **kwargs: keyword arguments forwarded unchanged to the `synthesize` method of each page

    Returns:
        list of synthesized pages, in the same order as `self.pages`
    """

    # Forward the options to every page: they were accepted by this method but never
    # passed down, so document-level calls always rendered with the default settings
    return [page.synthesize(**kwargs) for page in self.pages]
def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
    """Resolve a PIL font, falling back on the default PIL font when no family is requested
    and the recommended one cannot be loaded

    Args:
        font_family: font passed to `ImageFont.truetype`. If None, "FreeMono.ttf" is tried on Linux
            and "Arial.ttf" on any other platform
        font_size: size passed to `ImageFont.truetype`; it is not applied to the default PIL font fallback

    Returns:
        the loaded font (the object returned by `ImageFont.truetype`, or by `ImageFont.load_default`
        in the fallback case)
    """

    # An explicit request is not guarded: any loading error propagates to the caller
    if font_family is not None:
        return ImageFont.truetype(font_family, font_size)

    # No family requested: try the recommended font for the current OS
    try:
        return ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()
        # Trailing spaces matter here: adjacent string literals are concatenated as-is
        logging.warning("unable to load recommended font family. Loading default PIL font, "
                        "font size issues may be expected. "
                        "To prevent this, it is recommended to specify the value of 'font_family'.")

    return font
@@ -253,10 +255,12 @@ def synthetize_page( page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 font_size: size of the font, default font = 13 + font_family: family of the font Return: - A np array (drawn page) + the synthesized page """ + # Draw template h, w = page["dimensions"] response = 255 * np.ones((h, w, 3), dtype=np.int32) @@ -271,16 +275,11 @@ def synthetize_page( ymin, ymax = int(h * ymin), int(h * ymax) # White drawing context adapted to font size, 0.75 factor to convert pts --> pix - h_box, w_box = ymax - ymin, xmax - xmin - h_font, w_font = font_size, int(font_size * w_box / (h_box * 0.75)) - img = Image.new('RGB', (w_font, h_font), color=(255, 255, 255)) + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new('RGB', (xmax - xmin, ymax - ymin), color=(255, 255, 255)) d = ImageDraw.Draw(img) - # Draw in black the value of the word - d.text((0, 0), word["value"], font=ImageFont.load_default(), fill=(0, 0, 0)) - - # Resize back to box size - img = img.resize((w_box, h_box), Image.NEAREST) + d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) # Colorize if draw_proba if draw_proba: diff --git a/test/common/test_io_elements.py b/test/common/test_io_elements.py index 2c8f8818cb..fb87c37d44 100644 --- a/test/common/test_io_elements.py +++ b/test/common/test_io_elements.py @@ -196,6 +196,11 @@ def test_page(): # Show page.show(np.zeros((256, 256, 3), dtype=np.uint8), block=False) + # Synthesize + img = page.synthesize() + assert isinstance(img, np.ndarray) + assert img.shape == (*page_size, 3) + def test_document(): pages = _mock_pages() @@ -214,3 +219,7 @@ def test_document(): # Show doc.show([np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(len(pages))], block=False) + + # Synthesize + img_list = doc.synthesize() + assert isinstance(img_list, list) and len(img_list) == len(pages) diff --git a/test/common/test_utils_fonts.py 
def test_get_font():
    """Check that `get_font` returns a usable PIL font object when called without arguments"""

    # Imported locally so that the check below does not depend on the module-level imports
    from PIL.ImageFont import FreeTypeFont

    # Attempts to load recommended OS font
    font = get_font()

    # A font loaded through `ImageFont.truetype` is a FreeTypeFont, which does not inherit
    # from ImageFont: accept both so the test passes whether or not the OS font is installed
    assert isinstance(font, (ImageFont, FreeTypeFont))