Commit
Feat: Make detection training and inference Multiclass with new kie_predictor (#1097)

* feat: add handling of multiclass format in detection dataset loading with all transforms
* fix: fix loss computation and make training work
* feat: make loss computation vectorized and change target building to handle better class ids
* feat: add multiclass to pytorch and fix tests
* feat: add doc about Pages changes and multilabel dataset for training
* feat: fix api dockerfile and make it work with new changes
* fix reference tests
* refactor: refactor invert dict list and list dict function into one simpler
* fix: style and mypy
* docs: make it more clear for new data format
* explain why python version was upped
* add assert on length of tuple
* feat: add class names can be obtained from model config
* fix: prioritize class_names from dataset over model config
* fix: fix show samples in training
* fix: add check when target is dict and all values are numpy arrays
* fix: make detection target always dict and remove unnecessary made code from it
* fix: script detection evaluation tests and dataset tests with target as dict
* fix tests also on pytorch
* feat: Add kie predictor and io elements and visualization that come with it
* fix: revert ocr predictor to old format
* fix tests and add test for kie predictor
* up project version to 0.7.0
* update api to fix it and add kie route
* fix api version
* feat: sort class names to always have the same order.
* sort imports to avoid cyclic imports
* fix class_names default, use of tf_is_available avoid and copyright dates
* feat: update readme and doc with kie predictor
* feat: add loading backbone pretrained for multiclass detection, new elements for kie predictor (#6)
* feat: ✨ add load backbone
* feat: change kie predictor out
* fix new elements for kie, dataset when class is empty and fix and add tests
* fix api kie route
* fix evaluate kie script
* fix black
* remove commented code
* update README
* fix mypy
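Several of the changes listed above revolve around one idea: a detection target is always a dict keyed by class name, and class names are sorted so their order is stable. The following is a minimal sketch of that idea only, not the code from this commit. The helper name `group_boxes_by_class`, the `Box` alias, and the sample class names are all hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax), relative coordinates


def group_boxes_by_class(
    annotations: Sequence[Tuple[str, Box]],
    class_names: Sequence[str] = (),
) -> Dict[str, List[Box]]:
    """Hypothetical helper (not part of docTR): build a per-class detection target."""
    # Sort the names so the class order is the same on every run.
    names = sorted(set(class_names) | {name for name, _ in annotations})
    # Every known class gets an entry, even when it has no box on this page.
    target: Dict[str, List[Box]] = {name: [] for name in names}
    for name, box in annotations:
        target[name].append(box)
    return target


annotations = [
    ("total", (0.70, 0.80, 0.90, 0.85)),
    ("date", (0.10, 0.05, 0.30, 0.10)),
    ("total", (0.70, 0.86, 0.90, 0.91)),
]
target = group_boxes_by_class(annotations, class_names=["vendor"])
print(list(target))  # ['date', 'total', 'vendor']
```

Giving empty classes an explicit empty list is a design choice of this sketch; it keeps downstream code free of missing-key checks.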
1 parent b5ed162 commit e66ce01
Showing 55 changed files with 2,213 additions and 298 deletions.
@@ -0,0 +1,29 @@
+# Copyright (C) 2022, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Dict, List
+
+from fastapi import APIRouter, File, UploadFile, status
+
+from app.schemas import OCROut
+from app.vision import kie_predictor
+from doctr.io import decode_img_as_tensor
+
+router = APIRouter()
+
+
+@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE")
+async def perform_kie(file: UploadFile = File(...)):
+    """Runs docTR KIE model to analyze the input image"""
+    img = decode_img_as_tensor(file.file.read())
+    out = kie_predictor([img])
+
+    return {
+        class_name: [
+            OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
+            for prediction in out.pages[0].predictions[class_name]
+        ]
+        for class_name in out.pages[0].predictions.keys()
+    }
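The response shape built by the route above can be tried without FastAPI or a model. This is a minimal sketch under stated assumptions: `Prediction` and `Page` are stand-in dataclasses that only mimic the attributes the route reads (`geometry`, `value`, `predictions`), the geometry is assumed to be a pair of corner points, and `serialize_page` and the class key `"words"` are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Point = Tuple[float, float]


@dataclass
class Prediction:
    # Stand-in type; assumed geometry layout: ((xmin, ymin), (xmax, ymax)).
    geometry: Tuple[Point, Point]
    value: str


@dataclass
class Page:
    # Stand-in type: predictions grouped by class name.
    predictions: Dict[str, List[Prediction]] = field(default_factory=dict)


def serialize_page(page: Page) -> Dict[str, List[dict]]:
    """Flatten each prediction's two corner points into a 4-value box, grouped by class."""
    return {
        class_name: [
            {"box": (*pred.geometry[0], *pred.geometry[1]), "value": pred.value}
            for pred in preds
        ]
        for class_name, preds in page.predictions.items()
    }


page = Page(predictions={"words": [Prediction(((0.75, 0.18), (0.82, 0.20)), "Hello")]})
payload = serialize_page(page)
print(payload)  # {'words': [{'box': (0.75, 0.18, 0.82, 0.2), 'value': 'Hello'}]}
```

Iterating over `.items()` instead of `.keys()` followed by a lookup is a stylistic choice of the sketch; the resulting dict is the same.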
@@ -0,0 +1,29 @@
+import numpy as np
+import pytest
+from scipy.optimize import linear_sum_assignment
+
+from doctr.utils.metrics import box_iou
+
+
+@pytest.mark.asyncio
+async def test_perform_kie(test_app_asyncio, mock_detection_image):
+
+    response = await test_app_asyncio.post("/kie", files={"file": mock_detection_image})
+    assert response.status_code == 200
+    json_response = response.json()
+
+    gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32)
+    gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654
+    gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339
+    gt_labels = ["Hello", "world!"]
+
+    # Check that IoU with GT if reasonable
+    assert isinstance(json_response, dict) and len(list(json_response.values())[0]) == gt_boxes.shape[0]
+    pred_boxes = np.array([elt["box"] for json_out in json_response.values() for elt in json_out])
+    pred_labels = np.array([elt["value"] for json_out in json_response.values() for elt in json_out])
+    iou_mat = box_iou(gt_boxes, pred_boxes)
+    gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)
+    is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8
+    gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept]
+    assert gt_idxs.shape[0] == gt_boxes.shape[0]
+    assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs))
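The matching step in the test above pairs each ground-truth box with one prediction so that total IoU is maximal, then keeps pairs with IoU of at least 0.8. Below is a dependency-free sketch of that idea. It deliberately swaps `linear_sum_assignment` and `box_iou` for a brute-force search over permutations and a hand-written IoU, so it only suits a handful of boxes. The function names and the predicted boxes are hypothetical.

```python
from itertools import permutations
from typing import List, Sequence, Tuple

Box = Sequence[float]  # (xmin, ymin, xmax, ymax)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def match_boxes(gt: List[Box], pred: List[Box], thresh: float = 0.8) -> List[Tuple[int, int]]:
    """Brute-force one-to-one matching maximizing total IoU; assumes len(gt) <= len(pred)."""
    best = max(
        permutations(range(len(pred)), len(gt)),
        key=lambda perm: sum(iou(gt[i], pred[j]) for i, j in enumerate(perm)),
    )
    # Keep only the pairs whose overlap clears the threshold.
    return [(i, j) for i, j in enumerate(best) if iou(gt[i], pred[j]) >= thresh]


width, height = 1654, 2339  # page size in pixels, as in the test above
gt = [
    (1240 / width, 430 / height, 1355 / width, 470 / height),
    (1360 / width, 430 / height, 1495 / width, 470 / height),
]
# Hypothetical predictions: the same boxes, slightly shrunk and in reverse order.
pred = [
    (1362 / width, 431 / height, 1495 / width, 470 / height),
    (1240 / width, 430 / height, 1353 / width, 469 / height),
]
matches = match_boxes(gt, pred)
print(matches)  # [(0, 1), (1, 0)]
```

The brute-force search grows factorially with the number of boxes, which is why an assignment solver is the better fit once pages contain more than a few words.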
@@ -1,3 +1,3 @@
-from . import datasets, io, models, transforms, utils
+from . import io, datasets, models, transforms, utils
 from .file_utils import is_tf_available, is_torch_available
 from .version import __version__  # noqa: F401