Feat: Make detection training and inference Multiclass with new kie_predictor (#1097)

* feat: add handling of multiclass format in detection dataset loading with all transforms

* fix: fix loss computation and make training work

* feat: make loss computation vectorized and change target building to handle better class ids

* feat: add multiclass to pytorch and fix tests

* feat: add doc about Pages changes and multilabel dataset for training

* feat: fix api dockerfile and make it work with new changes

* fix reference tests

* refactor: combine the invert dict-list and list-dict helpers into one simpler function

* fix: style and mypy

* docs: make the new data format clearer

* explain why the Python version was bumped

* add assert on length of tuple

* feat: class names can now be obtained from the model config

* fix: prioritize class_names from dataset over model config

* fix: fix show samples in training

* fix: add check when target is dict and all values are numpy arrays

* fix: make detection target always a dict and remove unnecessary code from it

* fix: script detection evaluation tests and dataset tests with target as dict

* fix tests also on pytorch

* feat: Add kie predictor and io elements and visualization that come with it

* fix: revert ocr predictor to old format

* fix tests and add test for kie predictor

* up project version to 0.7.0

* update api to fix it and add kie route

* fix api version

* feat: sort class names to always have the same order.

* sort imports to avoid cyclic imports

* fix class_names default, avoid use of tf_is_available, and update copyright dates

* feat: update readme and doc with kie predictor

* feat: add loading backbone pretrained for multiclass detection, new elements for kie predictor (#6)

* feat: ✨ add load backbone

* feat: change kie predictor out

* fix new elements for KIE and the dataset when a class is empty; fix and add tests

* fix api kie route

* fix evaluate kie script

* fix black

* remove commented code

* update README

* fix mypy
aminemindee authored Dec 19, 2022
1 parent b5ed162 commit e66ce01
Showing 55 changed files with 2,213 additions and 298 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/docker.yml
@@ -28,14 +28,6 @@ jobs:
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install poetry
        uses: abatilo/actions-poetry@v2.0.0
        with:
          poetry-version: 1.1.13
      - name: Lock the requirements
        run: |
          cd api
          make lock
      - name: Build & run docker
        run: cd api && docker-compose up -d --build
      - name: Ping server
4 changes: 3 additions & 1 deletion .github/workflows/scripts.yml
@@ -140,7 +140,9 @@ jobs:
          python -m pip install --upgrade pip
          pip install -e .[torch] --upgrade
      - name: Run evaluation script
        run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
        run: |
          python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
          python scripts/evaluate_kie.py db_resnet50 crnn_vgg16_bn --samples 10
  test-collectenv:
    runs-on: ${{ matrix.os }}
32 changes: 30 additions & 2 deletions README.md
@@ -99,6 +99,31 @@ You can also export them as a nested dict, more appropriate for JSON format:
json_output = result.export()
```

### Use the KIE predictor
The KIE predictor is a more flexible predictor than the OCR predictor, as its detection model can detect multiple classes in a document. For example, you can have a detection model that detects just dates and addresses in a document.

The KIE predictor makes it possible to combine a multi-class detector with a recognition model and to have the whole pipeline already set up for you.

```python
from doctr.io import DocumentFile
from doctr.models import kie_predictor

# Model
model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
# Analyze
result = model(doc)

predictions = result.pages[0].predictions
for class_name in predictions.keys():
    list_predictions = predictions[class_name]
    for prediction in list_predictions:
        print(f"Prediction for {class_name}: {prediction}")
```
The KIE predictor returns its results per page as a dictionary, where each key is a class name and its value is the list of predictions for that class.
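As a minimal sketch continuing the snippet above, these per-class predictions can be flattened into plain Python data; this assumes each prediction exposes the `value` and `geometry` attributes that the new `/kie` API route in this commit relies on:

```python
# Continues the example above: `predictions` maps class names to lists of prediction objects.
# Assumption: each prediction exposes `value` and `geometry` ((x_min, y_min), (x_max, y_max)).
serializable = {
    class_name: [
        {"value": prediction.value, "box": [*prediction.geometry[0], *prediction.geometry[1]]}
        for prediction in class_predictions
    ]
    for class_name, class_predictions in predictions.items()
}
print(serializable)
```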


### If you are looking for support from the Mindee team
[![Bad OCR test detection image asking the developer if they need help](https://github.com/mindee/doctr/releases/download/v0.5.1/doctr-need-help.png)](https://mindee.com/product/doctr)

@@ -247,7 +272,10 @@ Looking to integrate docTR into your API? Here is a template to get you started
#### Deploy your API locally
Specific dependencies are required to run the API template, which you can install as follows:
```shell
pip install -r api/requirements.txt
cd api/
pip install poetry
make lock
pip install -r requirements.txt
```
You can now run your API locally:

@@ -262,7 +290,7 @@
PORT=8002 docker-compose up -d --build

#### What you have deployed

Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your three functional routes ("/detection", "/recognition", "/ocr"). Here is an example with Python to send a request to the OCR route:
Your API should now be running locally on port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route:

```python
import requests
15 changes: 8 additions & 7 deletions api/Dockerfile
@@ -7,17 +7,18 @@ ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PYTHONPATH "${PYTHONPATH}:/app"

# copy requirements file
COPY requirements.txt /app/requirements.txt
RUN apt-get update \
&& apt-get install --no-install-recommends ffmpeg libsm6 libxext6 make -y \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

COPY pyproject.toml /app/pyproject.toml
COPY Makefile /app/Makefile

RUN apt-get update \
&& apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y \
&& pip install --upgrade pip setuptools wheel \
RUN pip install --upgrade pip setuptools wheel poetry \
&& make lock \
&& pip install -r /app/requirements.txt \
&& pip cache purge \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/* \
&& rm -rf /root/.cache/pip

# copy project
2 changes: 1 addition & 1 deletion api/Makefile
@@ -5,7 +5,7 @@
lock:
poetry lock
poetry export -f requirements.txt --without-hashes --output requirements.txt
poetry export -f requirements.txt --without-hashes --dev --output requirements-dev.txt
poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt

# Run the docker
run:
3 changes: 2 additions & 1 deletion api/app/main.py
@@ -9,7 +9,7 @@
from fastapi.openapi.utils import get_openapi

from app import config as cfg
from app.routes import detection, ocr, recognition
from app.routes import detection, kie, ocr, recognition

app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION)

@@ -18,6 +18,7 @@
app.include_router(recognition.router, prefix="/recognition", tags=["recognition"])
app.include_router(detection.router, prefix="/detection", tags=["detection"])
app.include_router(ocr.router, prefix="/ocr", tags=["ocr"])
app.include_router(kie.router, prefix="/kie", tags=["kie"])


# Middleware
3 changes: 2 additions & 1 deletion api/app/routes/detection.py
@@ -9,6 +9,7 @@

from app.schemas import DetectionOut
from app.vision import det_predictor
from doctr.file_utils import CLASS_NAME
from doctr.io import decode_img_as_tensor

router = APIRouter()
@@ -19,4 +20,4 @@ async def text_detection(file: UploadFile = File(...)):
"""Runs docTR text detection model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
boxes = det_predictor([img])[0]
return [DetectionOut(box=box.tolist()) for box in boxes[:, :-1]]
return [DetectionOut(box=box.tolist()) for box in boxes[CLASS_NAME][:, :-1]]
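As the updated route suggests, the detection predictor now returns one dictionary per image, keyed by class name (single-class models expose only the default `CLASS_NAME` key). A minimal sketch of consuming that output directly; treating the last column as a confidence score is an assumption based on the `[:, :-1]` slice above:

```python
from app.vision import det_predictor  # same predictor object as in the route above
from doctr.io import decode_img_as_tensor

# Sketch: inspect the multiclass detection output for one image (the path is illustrative).
with open("sample_page.jpg", "rb") as f:
    img = decode_img_as_tensor(f.read())

out = det_predictor([img])[0]  # dict: class name -> (N, 5) array of boxes
for class_name, boxes in out.items():
    coords, scores = boxes[:, :-1], boxes[:, -1]  # last column assumed to be a score
    print(class_name, coords.shape, scores.mean() if len(scores) else None)
```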
29 changes: 29 additions & 0 deletions api/app/routes/kie.py
@@ -0,0 +1,29 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict, List

from fastapi import APIRouter, File, UploadFile, status

from app.schemas import OCROut
from app.vision import kie_predictor
from doctr.io import decode_img_as_tensor

router = APIRouter()


@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(file: UploadFile = File(...)):
"""Runs docTR KIE model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
out = kie_predictor([img])

return {
class_name: [
OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
for prediction in out.pages[0].predictions[class_name]
]
for class_name in out.pages[0].predictions.keys()
}
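A minimal client sketch for this new route, mirroring the OCR request example in the README; the host, port, and image path are assumptions:

```python
import requests

# Sketch: send an image to the /kie route of a locally running API instance.
with open("/path/to/your/img.jpg", "rb") as f:
    data = f.read()

response = requests.post("http://localhost:8002/kie", files={"file": data})
print(response.json())  # {class_name: [{"box": [...], "value": ...}, ...]}
```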
4 changes: 3 additions & 1 deletion api/app/routes/ocr.py
@@ -22,5 +22,7 @@ async def perform_ocr(file: UploadFile = File(...)):

    return [
        OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value)
        for word in out.pages[0].blocks[0].lines[0].words
        for block in out.pages[0].blocks
        for line in block.lines
        for word in line.words
    ]
3 changes: 2 additions & 1 deletion api/app/vision.py
@@ -9,8 +9,9 @@
if any(gpu_devices):
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from doctr.models import ocr_predictor
from doctr.models import kie_predictor, ocr_predictor

predictor = ocr_predictor(pretrained=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor
kie_predictor = kie_predictor(pretrained=True)
6 changes: 3 additions & 3 deletions api/pyproject.toml
@@ -4,16 +4,16 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "doctr-api"
version = "0.5.2a0"
version = "0.7.1a0"
description = "Backend template for your OCR API with docTR"
authors = ["Mindee <contact@mindee.com>"]
license = "Apache-2.0"

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.8.2,<3.11" # pypdfium2 needs a python version above 3.8.2
tensorflow = ">=2.9.0,<3.0.0"
tensorflow-addons = ">=0.17.1"
python-doctr = ">=0.2.0"
python-doctr = { version = ">=0.7.0", extras = ['tf'] }
# Fastapi: minimum version required to avoid pydantic error
# cf. https://github.com/tiangolo/fastapi/issues/4168
fastapi = ">=0.73.0"
29 changes: 29 additions & 0 deletions api/tests/routes/test_kie.py
@@ -0,0 +1,29 @@
import numpy as np
import pytest
from scipy.optimize import linear_sum_assignment

from doctr.utils.metrics import box_iou


@pytest.mark.asyncio
async def test_perform_kie(test_app_asyncio, mock_detection_image):

    response = await test_app_asyncio.post("/kie", files={"file": mock_detection_image})
    assert response.status_code == 200
    json_response = response.json()

    gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32)
    gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654
    gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339
    gt_labels = ["Hello", "world!"]

    # Check that IoU with GT is reasonable
    assert isinstance(json_response, dict) and len(list(json_response.values())[0]) == gt_boxes.shape[0]
    pred_boxes = np.array([elt["box"] for json_out in json_response.values() for elt in json_out])
    pred_labels = np.array([elt["value"] for json_out in json_response.values() for elt in json_out])
    iou_mat = box_iou(gt_boxes, pred_boxes)
    gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)
    is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8
    gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept]
    assert gt_idxs.shape[0] == gt_boxes.shape[0]
    assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs))
2 changes: 2 additions & 0 deletions docs/source/modules/models.rst
@@ -81,6 +81,8 @@ doctr.models.zoo

.. autofunction:: doctr.models.ocr_predictor

.. autofunction:: doctr.models.kie_predictor


doctr.models.factory
--------------------
2 changes: 1 addition & 1 deletion doctr/__init__.py
@@ -1,3 +1,3 @@
from . import datasets, io, models, transforms, utils
from . import io, datasets, models, transforms, utils
from .file_utils import is_tf_available, is_torch_available
from .version import __version__ # noqa: F401
11 changes: 10 additions & 1 deletion doctr/datasets/datasets/base.py
@@ -8,6 +8,9 @@
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np

from doctr.file_utils import copy_tensor
from doctr.io.image import get_img_shape
from doctr.utils.data import download_from_url

@@ -55,7 +58,13 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
            img = self.img_transforms(img)

        if self.sample_transforms is not None:
            img, target = self.sample_transforms(img, target)
            if isinstance(target, dict) and all([isinstance(item, np.ndarray) for item in target.values()]):
                img_transformed = copy_tensor(img)
                for class_name, bboxes in target.items():
                    img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
                img = img_transformed
            else:
                img, target = self.sample_transforms(img, target)

        return img, target

6 changes: 6 additions & 0 deletions doctr/datasets/datasets/pytorch.py
@@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(target[0], str) or isinstance(
                target[0], np.ndarray
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, str) or isinstance(
                target, np.ndarray
6 changes: 6 additions & 0 deletions doctr/datasets/datasets/tensorflow.py
@@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(target[0], str) or isinstance(
                target[0], np.ndarray
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, str) or isinstance(
                target, np.ndarray
47 changes: 39 additions & 8 deletions doctr/datasets/detection.py
@@ -5,14 +5,14 @@

import json
import os
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np

from doctr.io.image import get_img_shape
from doctr.utils.geometry import convert_to_relative_coords
from doctr.file_utils import CLASS_NAME

from .datasets import AbstractDataset
from .utils import pre_transform_multiclass

__all__ = ["DetectionDataset"]

@@ -41,24 +41,55 @@ def __init__(
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=lambda img, boxes: (img, convert_to_relative_coords(boxes, get_img_shape(img))),
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # File existence check
        self._class_names: List = []
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        self.data: List[Tuple[str, np.ndarray]] = []
        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(self.root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

            polygons: np.ndarray = np.asarray(label["polygons"], dtype=np_dtype)
            geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1)
            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)

            self.data.append((img_name, np.asarray(geoms, dtype=np_dtype)))
            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
    ) -> Tuple[np.ndarray, List[str]]:
        """format polygons into an array
        Args:
            polygons: the bounding boxes
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array
        Returns:
            geoms: bounding boxes as np array
            polygons_classes: list of classes for each bounding box
        """
        if isinstance(polygons, list):
            self._class_names += [CLASS_NAME]
            polygons_classes = [CLASS_NAME for _ in polygons]
            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            self._class_names += list(polygons.keys())
            polygons_classes = [k for k, v in polygons.items() for _ in v]
            _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
        geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self):
        return sorted(list(set(self._class_names)))
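For reference, a minimal sketch of the two label layouts `format_polygons` accepts, a plain list of polygons for single-class data and a dict mapping class names to polygons for multiclass data; the file names, class names, and coordinates are illustrative:

```python
import json

# Sketch of the labels.json structure parsed above (values are illustrative).
labels = {
    # single-class entry: a plain list of 4-point polygons, mapped to the default CLASS_NAME
    "img_1.jpg": {"polygons": [[[10, 10], [120, 10], [120, 40], [10, 40]]]},
    # multiclass entry: a dict mapping each class name to its list of polygons
    "img_2.jpg": {
        "polygons": {
            "date": [[[15, 20], [90, 20], [90, 45], [15, 45]]],
            "address": [[[30, 60], [250, 60], [250, 110], [30, 110]]],
        }
    },
}
with open("labels.json", "w") as f:
    json.dump(labels, f)
```

Each key must correspond to an image file in `img_folder`, and the class names collected across entries end up in the sorted `class_names` property.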