[prototype] object det replacement / init contrib modules (#1534)
felixdittrich92 authored Apr 25, 2024
1 parent c957cf8 commit 630d925
Showing 26 changed files with 362 additions and 432 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
@@ -34,7 +34,7 @@ jobs:
run: |
coverage run -m pytest tests/common/
coverage xml -o coverage-common.xml
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v4
with:
name: coverage-common
path: ./coverage-common.xml
@@ -67,7 +67,7 @@ jobs:
run: |
coverage run -m pytest tests/tensorflow/
coverage xml -o coverage-tf.xml
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v4
with:
name: coverage-tf
path: ./coverage-tf.xml
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
- name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4
with:
flags: unittests
fail_ci_if_error: true
2 changes: 1 addition & 1 deletion .github/workflows/public_docker_images.yml
@@ -23,7 +23,7 @@ jobs:
matrix:
# Must match version at https://www.python.org/ftp/python/
python: ["3.9.18", "3.10.13", "3.11.8"]
framework: ["tf", "torch", "tf,viz,html", "torch,viz,html"]
framework: ["tf", "torch", "tf,viz,html,contrib", "torch,viz,html,contrib"]
system: ["cpu", "gpu"]

# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
41 changes: 0 additions & 41 deletions .github/workflows/references.yml
@@ -412,44 +412,3 @@ jobs:
      - if: matrix.framework == 'pytorch'
        name: Benchmark latency (PT)
        run: python references/detection/latency_pytorch.py db_mobilenet_v3_large --it 5 --size 512
-
-  latency-object-detection:
-    runs-on: ${{ matrix.os }}
-    strategy:
-      fail-fast: false
-      matrix:
-        os: [ubuntu-latest]
-        python: ["3.9"]
-        framework: [pytorch]
-    steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python }}
-          architecture: x64
-      - if: matrix.framework == 'tensorflow'
-        name: Cache python modules (TF)
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}
-      - if: matrix.framework == 'pytorch'
-        name: Cache python modules (PT)
-        uses: actions/cache@v4
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
-      - if: matrix.framework == 'tensorflow'
-        name: Install dependencies (TF)
-        run: |
-          python -m pip install --upgrade pip
-          pip install -e .[tf,viz,html] --upgrade
-      - if: matrix.framework == 'pytorch'
-        name: Install dependencies (PT)
-        run: |
-          python -m pip install --upgrade pip
-          pip install -e .[torch,viz,html] --upgrade
-      - if: matrix.framework == 'pytorch'
-        name: Benchmark latency (PT)
-        run: python references/obj_detection/latency_pytorch.py fasterrcnn_mobilenet_v3_large_fpn --it 5 --size 512
8 changes: 4 additions & 4 deletions docs/source/getting_started/installing.rst
@@ -38,16 +38,16 @@ We strive towards reducing framework-specific dependencies to a minimum, but some
      .. code:: bash

         pip install "python-doctr[tf]"
-        # or with preinstalled packages for visualization & html support
-        pip install "python-doctr[tf,viz,html]"
+        # or with preinstalled packages for visualization & html & contrib module support
+        pip install "python-doctr[tf,viz,html,contrib]"
.. tab:: PyTorch

      .. code:: bash

         pip install "python-doctr[torch]"
-        # or with preinstalled packages for visualization & html support
-        pip install "python-doctr[torch,viz,html]"
+        # or with preinstalled packages for visualization & html & contrib module support
+        pip install "python-doctr[torch,viz,html,contrib]"
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -77,6 +77,7 @@ Supported datasets

using_doctr/using_models
using_doctr/using_datasets
+   using_doctr/using_contrib_modules
using_doctr/sharing_models
using_doctr/using_model_export
using_doctr/custom_models_training
@@ -88,6 +89,7 @@ Supported datasets
:caption: Package Reference
:hidden:

+   modules/contrib
modules/datasets
modules/io
modules/models
13 changes: 13 additions & 0 deletions docs/source/modules/contrib.rst
@@ -0,0 +1,13 @@
doctr.contrib
=============

.. currentmodule:: doctr.contrib

This module contains all the available contribution modules for docTR.


Supported contribution modules
------------------------------
Here are all the available contribution modules:

.. autoclass:: ArtefactDetector
51 changes: 51 additions & 0 deletions docs/source/using_doctr/using_contrib_modules.rst
@@ -0,0 +1,51 @@
Integrate contributions into your pipeline
==========================================

The `contrib` module provides a collection of additional features which could be relevant for your document analysis pipeline.
The following sections will give you an overview of the available modules and features.

.. currentmodule:: doctr.contrib


Available contribution modules
------------------------------

**NOTE:** To use the contrib module, you need to install the `onnxruntime` package. You can install it using the following command:

.. code:: bash

   pip install "python-doctr[contrib]"
   # or install the runtime directly
   pip install onnxruntime  # use onnxruntime-gpu for GPU inference
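To verify the runtime is usable (a minimal sketch; `get_available_providers` is part of the public `onnxruntime` API):

.. code:: python3

   import onnxruntime

   # e.g. ['CPUExecutionProvider'], or ['CUDAExecutionProvider', 'CPUExecutionProvider'] with onnxruntime-gpu
   print(onnxruntime.get_available_providers())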
Here are all contribution modules that are available through docTR:

ArtefactDetection
^^^^^^^^^^^^^^^^^

The ArtefactDetection module provides a predictor to detect artefacts in document images, such as logos, QR codes and bar codes.
It is built on YOLOv8, a state-of-the-art object detection architecture.

.. code:: python3

   from doctr.io import DocumentFile
   from doctr.contrib.artefacts import ArtefactDetector

   # Load the document
   doc = DocumentFile.from_images(["path/to/your/image"])
   detector = ArtefactDetector(batch_size=2, conf_threshold=0.5, iou_threshold=0.5)
   artefacts = detector(doc)
   # Visualize the detected artefacts
   detector.show()
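Each page in the result is a plain list of detections; as defined by the postprocessing in this commit, every detection is a dictionary with `label`, `confidence` and `box` keys, where `box` holds rescaled `[xmin, ymin, xmax, ymax]` pixel coordinates:

.. code:: python3

   for page in artefacts:
       for artefact in page:
           xmin, ymin, xmax, ymax = artefact["box"]
           print(f"{artefact['label']} ({artefact['confidence']:.2f}) at [{xmin}, {ymin}, {xmax}, {ymax}]")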
You can also plug in your own custom-trained YOLOv8 model to detect artefacts or anything else you need.
Reference: `YOLOv8 <https://github.com/ultralytics/ultralytics>`_

**NOTE:** The YOLOv8 model must be provided as an ONNX export with a dynamic batch size; Oriented Bounding Box (OBB) inference is not yet supported.

.. code:: python3

   from doctr.contrib.artefacts import ArtefactDetector

   detector = ArtefactDetector(model_path="path/to/your/model.onnx", labels=["table", "figure"])
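For reference, a custom YOLOv8 checkpoint can be exported with a dynamic batch dimension roughly as follows (a minimal sketch assuming the `ultralytics` package; the checkpoint path is a placeholder):

.. code:: python3

   from ultralytics import YOLO

   model = YOLO("path/to/your/best.pt")  # your custom-trained checkpoint
   model.export(format="onnx", dynamic=True)  # dynamic=True exports dynamic input/batch dimensions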
2 changes: 1 addition & 1 deletion doctr/__init__.py
@@ -1,3 +1,3 @@
-from . import io, models, datasets, transforms, utils
+from . import io, models, datasets, contrib, transforms, utils
from .file_utils import is_tf_available, is_torch_available
from .version import __version__ # noqa: F401
Empty file added doctr/contrib/__init__.py
131 changes: 131 additions & 0 deletions doctr/contrib/artefacts.py
@@ -0,0 +1,131 @@
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np

from doctr.file_utils import requires_package

from .base import _BasePredictor

__all__ = ["ArtefactDetector"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "yolov8_artefact": {
        "input_shape": (3, 1024, 1024),
        "labels": ["bar_code", "qr_code", "logo", "photo"],
        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0",
    },
}


class ArtefactDetector(_BasePredictor):
    """
    A class to detect artefacts in images

    >>> from doctr.io import DocumentFile
    >>> from doctr.contrib.artefacts import ArtefactDetector
    >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
    >>> detector = ArtefactDetector()
    >>> results = detector(doc)

    Args:
    ----
        arch: the architecture to use
        batch_size: the batch size to use
        model_path: the path to the model to use
        labels: the labels to use
        input_shape: the input shape to use
        conf_threshold: the confidence threshold to use
        iou_threshold: the intersection over union threshold to use
        **kwargs: additional arguments to be passed to `download_from_url`
    """

    def __init__(
        self,
        arch: str = "yolov8_artefact",
        batch_size: int = 2,
        model_path: Optional[str] = None,
        labels: Optional[List[str]] = None,
        input_shape: Optional[Tuple[int, int, int]] = None,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
        self.labels = labels or default_cfgs[arch]["labels"]
        self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # Resize to the model input size, convert HWC -> CHW and scale pixel values to [0, 1]
        return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
        results = []

        for batch in zip(output, input_images):
            for out, img in zip(batch[0], batch[1]):
                org_height, org_width = img.shape[:2]
                width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1]
                for res in out:
                    sample_results = []
                    # YOLOv8 output has shape (4 + num_classes, num_anchors): each transposed row is
                    # (center_x, center_y, width, height, class_score_0, ..., class_score_n)
                    for row in np.transpose(np.squeeze(res)):
                        classes_scores = row[4:]
                        max_score = np.amax(classes_scores)
                        if max_score >= self.conf_threshold:
                            class_id = np.argmax(classes_scores)
                            x, y, w, h = row[0], row[1], row[2], row[3]
                            # to rescaled xmin, ymin, xmax, ymax
                            xmin = int((x - w / 2) * width_scale)
                            ymin = int((y - h / 2) * height_scale)
                            xmax = int((x + w / 2) * width_scale)
                            ymax = int((y + h / 2) * height_scale)

                            sample_results.append({
                                "label": self.labels[class_id],
                                "confidence": float(max_score),
                                "box": [xmin, ymin, xmax, ymax],
                            })

                    # Filter out overlapping boxes with non-maximum suppression
                    # NOTE: cv2.dnn.NMSBoxes expects (x, y, w, h) boxes, so convert from (xmin, ymin, xmax, ymax)
                    boxes = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in (res["box"] for res in sample_results)]
                    scores = [res["confidence"] for res in sample_results]
                    keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold)  # type: ignore[arg-type]
                    sample_results = [sample_results[i] for i in keep_indices]

                    results.append(sample_results)

        self._results = results
        return results

    def show(self, **kwargs: Any) -> None:
        """
        Display the results

        Args:
        ----
            **kwargs: additional keyword arguments to be passed to `plt.show`
        """
        requires_package("matplotlib", "`.show()` requires matplotlib installed")
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        # visualize the results with matplotlib
        if self._results and self._inputs:
            for img, res in zip(self._inputs, self._results):
                plt.figure(figsize=(10, 10))
                plt.imshow(img)
                for obj in res:
                    xmin, ymin, xmax, ymax = obj["box"]
                    label = obj["label"]
                    plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
                    plt.gca().add_patch(
                        Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
                    )
                plt.show(**kwargs)
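For context, `ArtefactDetector` builds on `_BasePredictor` from `doctr/contrib/base.py`, which this commit adds but which is not shown above. The following is a rough sketch of the interface the code relies on, with attribute and method names inferred from their use in `artefacts.py`; the ONNX-session detail is an assumption, not the verbatim implementation:

.. code:: python3

   from typing import Any, List, Optional

   import numpy as np


   class _BasePredictor:
       """Sketch: shared ONNX-based inference loop for contrib predictors."""

       def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> None:
           self.batch_size = batch_size
           # assumed: resolve `model_path` or download the ONNX model from `url`
           # (extra kwargs forwarded to `download_from_url`), then open an onnxruntime session
           self._inputs: List[np.ndarray] = []
           self._results: List[Any] = []

       def preprocess(self, img: np.ndarray) -> np.ndarray:
           # implemented by subclasses such as ArtefactDetector
           raise NotImplementedError

       def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
           # implemented by subclasses such as ArtefactDetector
           raise NotImplementedError

       def __call__(self, inputs: List[np.ndarray]) -> Any:
           # assumed: store the inputs (used later by `.show()`), preprocess and batch them,
           # run the ONNX session, then hand the raw outputs to `postprocess`
           self._inputs = inputs
           ...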