Skip to content

Commit

Permalink
Merge branch 'main' into doc-update
Browse files Browse the repository at this point in the history
  • Loading branch information
fg-mindee committed Sep 16, 2021
2 parents 16439db + cfc329f commit a4937e9
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 34 deletions.
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,26 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images()
result = model(doc)
```

To make sense of your model's predictions, you can visualize them as follows:
To make sense of your model's predictions, you can visualize them interactively as follows:

```python
result.show(doc)
```

![DocTR example](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif)
![Visualization sample](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif)

The ocr_predictor returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
Or even rebuild the original document from its predictions:

```python
import matplotlib.pyplot as plt

plt.imshow(result.synthesize()); plt.axis('off'); plt.show()
```

![Synthesis sample](https://github.com/mindee/doctr/releases/download/v0.3.1/synthesized_sample.png)


The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/io.html#document-structure):

You can also export them as a nested dict, more appropriate for JSON format:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Easy-to-use functions to make sense of your model's predictions.

.. autofunction:: visualize_page

.. autofunction:: synthesize_page


.. _metrics:

Expand Down
12 changes: 2 additions & 10 deletions doctr/datasets/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform

from doctr.io.image import tensor_from_pil
from doctr.utils.fonts import get_font
from ..datasets import AbstractDataset


Expand All @@ -31,16 +32,7 @@ def synthesize_char_img(char: str, size: int = 32, font_family: Optional[str] =
d = ImageDraw.Draw(img)

# Draw the character
if font_family is None:
try:
font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", size)
except OSError:
font = ImageFont.load_default()
logging.warning("unable to load specific font families. Loading default PIL font,"
"font size issues may be expected."
"To prevent this, it is recommended to specify the value of 'font_family'.")
else:
font = ImageFont.truetype(font_family, size)
font = get_font(font_family, size)
d.text((4, 0), char, font=font, fill=(255, 255, 255))

return img
Expand Down
20 changes: 19 additions & 1 deletion doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Tuple, Dict, List, Any, Optional, Union

from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
from doctr.utils.visualization import visualize_page
from doctr.utils.visualization import visualize_page, synthesize_page
from doctr.utils.common_types import BoundingBox, RotatedBbox
from doctr.utils.repr import NestedObject

Expand Down Expand Up @@ -244,6 +244,15 @@ def show(
visualize_page(self.export(), page, interactive=interactive)
plt.show(**kwargs)

def synthesize(self, **kwargs) -> np.ndarray:
"""Synthesize the page from the predictions
Returns:
synthesized page
"""

return synthesize_page(self.export(), **kwargs)

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
Expand Down Expand Up @@ -280,6 +289,15 @@ def show(self, pages: List[np.ndarray], **kwargs) -> None:
for img, result in zip(pages, self.pages):
result.show(img, **kwargs)

def synthesize(self, **kwargs) -> List[np.ndarray]:
"""Synthesize all pages from their predictions
Returns:
list of synthesized pages
"""

return [page.synthesize() for page in self.pages]

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .sar import *
from .zoo import *

del utils
del utils # type: ignore[name-defined]
15 changes: 10 additions & 5 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,16 @@ def extra_repr(self) -> str:

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
h, w = img.shape[:2]
random_scale = random.uniform(self.scale[0], self.scale[1])
random_ratio = random.uniform(self.ratio[0], self.ratio[1])
crop_h = math.sqrt(random_scale * random_ratio)
crop_w = math.sqrt(random_scale / random_ratio)
scale = random.uniform(self.scale[0], self.scale[1])
ratio = random.uniform(self.ratio[0], self.ratio[1])
crop_h = math.sqrt(scale * ratio)
crop_w = math.sqrt(scale / ratio)
start_x, start_y = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
crop_box = (int(start_x * w), int(start_y * h), int((start_x + crop_w) * w), int((start_y + crop_h) * h))
crop_box = (
max(0, int(round(start_x * w))),
max(0, int(round(start_y * h))),
min(int(round((start_x + crop_w) * w)), w - 1),
min(int(round((start_y + crop_h) * h)), h - 1)
)
croped_img, crop_boxes = F.crop_detection(img, target["boxes"], crop_box)
return croped_img, dict(boxes=crop_boxes)
28 changes: 28 additions & 0 deletions doctr/utils/fonts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import platform
import logging
from PIL import ImageFont
from typing import Optional

__all__ = ['get_font']


def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:

# Font selection
if font_family is None:
try:
font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
except OSError:
font = ImageFont.load_default()
logging.warning("unable to load recommended font family. Loading default PIL font,"
"font size issues may be expected."
"To prevent this, it is recommended to specify the value of 'font_family'.")
else:
font = ImageFont.truetype(font_family, font_size)

return font
21 changes: 10 additions & 11 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing import Tuple, List, Dict, Any, Union, Optional

from .common_types import BoundingBox, RotatedBbox
from .fonts import get_font

__all__ = ['visualize_page', 'synthetize_page', 'draw_boxes']
__all__ = ['visualize_page', 'synthesize_page', 'draw_boxes']


def rect_patch(
Expand Down Expand Up @@ -242,21 +243,24 @@ def visualize_page(
return fig


def synthetize_page(
def synthesize_page(
page: Dict[str, Any],
draw_proba: bool = False,
font_size: int = 13,
font_family: Optional[str] = None,
) -> np.ndarray:
"""Draw a the content of the element page (OCR response) on a blank page.
Args:
page: exported Page object to represent
draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0
font_size: size of the font, default font = 13
font_family: family of the font
Return:
A np array (drawn page)
the synthesized page
"""

# Draw template
h, w = page["dimensions"]
response = 255 * np.ones((h, w, 3), dtype=np.int32)
Expand All @@ -271,16 +275,11 @@ def synthetize_page(
ymin, ymax = int(h * ymin), int(h * ymax)

# White drawing context adapted to font size, 0.75 factor to convert pts --> pix
h_box, w_box = ymax - ymin, xmax - xmin
h_font, w_font = font_size, int(font_size * w_box / (h_box * 0.75))
img = Image.new('RGB', (w_font, h_font), color=(255, 255, 255))
font = get_font(font_family, int(0.75 * (ymax - ymin)))
img = Image.new('RGB', (xmax - xmin, ymax - ymin), color=(255, 255, 255))
d = ImageDraw.Draw(img)

# Draw in black the value of the word
d.text((0, 0), word["value"], font=ImageFont.load_default(), fill=(0, 0, 0))

# Resize back to box size
img = img.resize((w_box, h_box), Image.NEAREST)
d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))

# Colorize if draw_proba
if draw_proba:
Expand Down
9 changes: 9 additions & 0 deletions test/common/test_io_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def test_page():
# Show
page.show(np.zeros((256, 256, 3), dtype=np.uint8), block=False)

# Synthesize
img = page.synthesize()
assert isinstance(img, np.ndarray)
assert img.shape == (*page_size, 3)


def test_document():
pages = _mock_pages()
Expand All @@ -214,3 +219,7 @@ def test_document():

# Show
doc.show([np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(len(pages))], block=False)

# Synthesize
img_list = doc.synthesize()
assert isinstance(img_list, list) and len(img_list) == len(pages)
11 changes: 11 additions & 0 deletions test/common/test_utils_fonts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from PIL.ImageFont import ImageFont

from doctr.utils.fonts import get_font


def test_get_font():

# Attempts to load recommended OS font
font = get_font()

assert isinstance(font, ImageFont)
8 changes: 5 additions & 3 deletions test/common/test_utils_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def test_visualize_page():
visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100))


def test_draw_page():
def test_synthesize_page():
pages = _mock_pages()
visualization.synthetize_page(pages[0].export(), draw_proba=True)
visualization.synthetize_page(pages[0].export(), draw_proba=False)
visualization.synthesize_page(pages[0].export(), draw_proba=False)
render = visualization.synthesize_page(pages[0].export(), draw_proba=True)
assert isinstance(render, np.ndarray)
assert render.shape == (*pages[0].dimensions, 3)


def test_draw_boxes():
Expand Down

0 comments on commit a4937e9

Please sign in to comment.