-
Notifications
You must be signed in to change notification settings - Fork 463
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[prototype] object det replacement / init contrib modules (#1534)
- Loading branch information
1 parent
c957cf8
commit 630d925
Showing
26 changed files
with
362 additions
and
432 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
doctr.contrib | ||
============= | ||
|
||
.. currentmodule:: doctr.contrib | ||
|
||
This module contains all the available contribution modules for docTR. | ||
|
||
|
||
Supported contribution modules | ||
------------------------------ | ||
Here are all the available contribution modules: | ||
|
||
.. autoclass:: ArtefactDetector |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
Integrate contributions into your pipeline | ||
========================================== | ||
|
||
The `contrib` module provides a collection of additional features which could be relevant for your document analysis pipeline. | ||
The following sections will give you an overview of the available modules and features. | ||
|
||
.. currentmodule:: doctr.contrib | ||
|
||
|
||
Available contribution modules | ||
------------------------------ | ||
|
||
**NOTE:** To use the contrib module, you need to install the `onnxruntime` package. You can install it using the following command: | ||
|
||
.. code:: bash | ||
pip install python-doctr[contrib] | ||
# Or | ||
pip install onnxruntime # pip install onnxruntime-gpu | ||
Here are all contribution modules that are available through docTR: | ||
|
||
ArtefactDetection | ||
^^^^^^^^^^^^^^^^^ | ||
|
||
The ArtefactDetection module provides a set of functions to detect artefacts in the document images, such as logos, QR codes, bar codes, etc. | ||
It is based on the YOLOv8 architecture, which is a state-of-the-art object detection model. | ||
|
||
.. code:: python3 | ||
from doctr.io import DocumentFile | ||
from doctr.contrib.artefacts import ArtefactDetection | ||
# Load the document | ||
doc = DocumentFile.from_images(["path/to/your/image"]) | ||
detector = ArtefactDetection(batch_size=2, conf_threshold=0.5, iou_threshold=0.5) | ||
artefacts = detector(doc) | ||
# Visualize the detected artefacts | ||
detector.show() | ||
You can also use your custom trained YOLOv8 model to detect artefacts or anything else you need. | ||
Reference: `YOLOv8 <https://github.com/ultralytics/ultralytics>`_ | ||
|
||
**NOTE:** The YOLOv8 model (no Oriented Bounding Box (OBB) inference supported yet) needs to be provided as onnx exported model with a dynamic batch size. | ||
|
||
.. code:: python3 | ||
from doctr.contrib import ArtefactDetection | ||
detector = ArtefactDetection(model_path="path/to/your/model.onnx", labels=["table", "figure"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from . import io, models, datasets, transforms, utils | ||
from . import io, models, datasets, contrib, transforms, utils | ||
from .file_utils import is_tf_available, is_torch_available | ||
from .version import __version__ # noqa: F401 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Copyright (C) 2021-2024, Mindee. | ||
|
||
# This program is licensed under the Apache License 2.0. | ||
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | ||
|
||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from doctr.file_utils import requires_package | ||
|
||
from .base import _BasePredictor | ||
|
||
__all__ = ["ArtefactDetector"] | ||
|
||
default_cfgs: Dict[str, Dict[str, Any]] = { | ||
"yolov8_artefact": { | ||
"input_shape": (3, 1024, 1024), | ||
"labels": ["bar_code", "qr_code", "logo", "photo"], | ||
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0", | ||
}, | ||
} | ||
|
||
|
||
class ArtefactDetector(_BasePredictor): | ||
""" | ||
A class to detect artefacts in images | ||
>>> from doctr.io import DocumentFile | ||
>>> from doctr.contrib.artefacts import ArtefactDetector | ||
>>> doc = DocumentFile.from_images(["path/to/image.jpg"]) | ||
>>> detector = ArtefactDetector() | ||
>>> results = detector(doc) | ||
Args: | ||
---- | ||
arch: the architecture to use | ||
batch_size: the batch size to use | ||
model_path: the path to the model to use | ||
labels: the labels to use | ||
input_shape: the input shape to use | ||
mask_labels: the mask labels to use | ||
conf_threshold: the confidence threshold to use | ||
iou_threshold: the intersection over union threshold to use | ||
**kwargs: additional arguments to be passed to `download_from_url` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
arch: str = "yolov8_artefact", | ||
batch_size: int = 2, | ||
model_path: Optional[str] = None, | ||
labels: Optional[List[str]] = None, | ||
input_shape: Optional[Tuple[int, int, int]] = None, | ||
conf_threshold: float = 0.5, | ||
iou_threshold: float = 0.5, | ||
**kwargs: Any, | ||
) -> None: | ||
super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs) | ||
self.labels = labels or default_cfgs[arch]["labels"] | ||
self.input_shape = input_shape or default_cfgs[arch]["input_shape"] | ||
self.conf_threshold = conf_threshold | ||
self.iou_threshold = iou_threshold | ||
|
||
def preprocess(self, img: np.ndarray) -> np.ndarray: | ||
return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0) | ||
|
||
def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]: | ||
results = [] | ||
|
||
for batch in zip(output, input_images): | ||
for out, img in zip(batch[0], batch[1]): | ||
org_height, org_width = img.shape[:2] | ||
width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1] | ||
for res in out: | ||
sample_results = [] | ||
for row in np.transpose(np.squeeze(res)): | ||
classes_scores = row[4:] | ||
max_score = np.amax(classes_scores) | ||
if max_score >= self.conf_threshold: | ||
class_id = np.argmax(classes_scores) | ||
x, y, w, h = row[0], row[1], row[2], row[3] | ||
# to rescaled xmin, ymin, xmax, ymax | ||
xmin = int((x - w / 2) * width_scale) | ||
ymin = int((y - h / 2) * height_scale) | ||
xmax = int((x + w / 2) * width_scale) | ||
ymax = int((y + h / 2) * height_scale) | ||
|
||
sample_results.append({ | ||
"label": self.labels[class_id], | ||
"confidence": float(max_score), | ||
"box": [xmin, ymin, xmax, ymax], | ||
}) | ||
|
||
# Filter out overlapping boxes | ||
boxes = [res["box"] for res in sample_results] | ||
scores = [res["confidence"] for res in sample_results] | ||
keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) # type: ignore[arg-type] | ||
sample_results = [sample_results[i] for i in keep_indices] | ||
|
||
results.append(sample_results) | ||
|
||
self._results = results | ||
return results | ||
|
||
def show(self, **kwargs: Any) -> None: | ||
""" | ||
Display the results | ||
Args: | ||
---- | ||
**kwargs: additional keyword arguments to be passed to `plt.show` | ||
""" | ||
requires_package("matplotlib", "`.show()` requires matplotlib installed") | ||
import matplotlib.pyplot as plt | ||
from matplotlib.patches import Rectangle | ||
|
||
# visualize the results with matplotlib | ||
if self._results and self._inputs: | ||
for img, res in zip(self._inputs, self._results): | ||
plt.figure(figsize=(10, 10)) | ||
plt.imshow(img) | ||
for obj in res: | ||
xmin, ymin, xmax, ymax = obj["box"] | ||
label = obj["label"] | ||
plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red") | ||
plt.gca().add_patch( | ||
Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2) | ||
) | ||
plt.show(**kwargs) |
Oops, something went wrong.