diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f19c87e358..0aa5a44976 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -29,9 +29,8 @@ jobs: python-version: ${{ matrix.python }} architecture: x64 - name: Build & run docker - run: cd api && docker-compose up -d --build + run: cd api && make lock && make run - name: Ping server run: wget --spider --tries=12 http://localhost:8080/docs - name: Run docker test - run: | - docker-compose -f api/docker-compose.yml exec --no-TTY web pytest tests/ + run: cd api && make test diff --git a/api/Dockerfile b/api/Dockerfile index a158e44721..8038ed28c8 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -15,7 +15,7 @@ RUN apt-get update \ COPY pyproject.toml /app/pyproject.toml COPY Makefile /app/Makefile -RUN pip install --upgrade pip setuptools wheel poetry \ +RUN pip install --upgrade pip setuptools wheel \ && make lock \ && pip install -r /app/requirements.txt \ && pip cache purge \ diff --git a/api/Makefile b/api/Makefile index 689931dd29..09e9841e91 100644 --- a/api/Makefile +++ b/api/Makefile @@ -3,6 +3,7 @@ .PHONY: lock run stop test # Pin the dependencies lock: + pip install "poetry>=1.0" poetry lock poetry export -f requirements.txt --without-hashes --output requirements.txt poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt @@ -18,8 +19,8 @@ stop: # Run tests for the library test: docker compose up -d --build - docker cp requirements-dev.txt api_web_1:/app/requirements-dev.txt + docker cp requirements-dev.txt api_web:/app/requirements-dev.txt docker compose exec -T web pip install -r requirements-dev.txt - docker cp tests api_web_1:/app/tests - docker compose exec -T web pytest tests/ + docker cp tests api_web:/app/tests + docker compose exec -T web pytest tests/ -vv docker compose down diff --git a/api/README.md b/api/README.md index 426e191bf2..4126e808c5 100644 --- a/api/README.md +++ b/api/README.md @@ -35,16 +35,39
@@ with this snippet: ```python import requests + +headers = {"accept": "application/json"} +params = {"det_arch": "db_resnet50"} + with open('/path/to/your/img.jpg', 'rb') as f: - data = f.read() -print(requests.post("http://localhost:8080/detection", files={'file': data}).json()) + files = [ # application/pdf, image/jpeg, image/png supported + ("files", ("117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", f.read(), "image/jpeg")), + ] +print(requests.post("http://localhost:8080/detection", headers=headers, params=params, files=files).json()) ``` should yield ```json -[{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875]}, - {'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875]}] +[ + { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "geometries": [ + [ + 0.724609375, + 0.1787109375, + 0.7900390625, + 0.2080078125 + ], + [ + 0.6748046875, + 0.1796875, + 0.7314453125, + 0.20703125 + ] + ] + } +] ``` #### Text recognition @@ -56,15 +79,27 @@ with this snippet: ```python import requests + +headers = {"accept": "application/json"} +params = {"reco_arch": "crnn_vgg16_bn"} + with open('/path/to/your/img.jpg', 'rb') as f: - data = f.read() -print(requests.post("http://localhost:8080/recognition", files={'file': data}).json()) + files = [ # application/pdf, image/jpeg, image/png supported + ("files", ("117133599-c073fa00-ada4-11eb-831b-412de4d28341.jpeg", f.read(), "image/jpeg")), + ] +print(requests.post("http://localhost:8080/recognition", headers=headers, params=params, files=files).json()) ``` should yield ```json -{'value': 'invite'} +[ + { + "name": "117133599-c073fa00-ada4-11eb-831b-412de4d28341.jpeg", + "value": "invite", + "confidence": 1.0 + } +] ``` #### End-to-end OCR @@ -76,16 +111,78 @@ with this snippet: ```python import requests + +headers = {"accept": "application/json"} +params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"} + with open('/path/to/your/img.jpg', 'rb') as f: - data = f.read() 
-print(requests.post("http://localhost:8080/ocr", files={'file': data}).json()) + files = [ # application/pdf, image/jpeg, image/png supported + ("files", ("117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", f.read(), "image/jpeg")), + ] +print(requests.post("http://localhost:8080/ocr", headers=headers, params=params, files=files).json()) ``` should yield ```json -[{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875], - 'value': 'Hello'}, - {'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875], - 'value': 'world!'}] +[ + { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": { + "value": 0, + "confidence": null + }, + "language": { + "value": null, + "confidence": null + }, + "dimensions": [2339, 1654], + "items": [ + { + "blocks": [ + { + "geometry": [ + 0.7471996155154171, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "lines": [ + { + "geometry": [ + 0.7471996155154171, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "words": [ + { + "value": "Hello", + "geometry": [ + 0.7471996155154171, + 0.1796875, + 0.8272978149561669, + 0.20703125 + ], + "confidence": 1.0 + }, + { + "value": "world!", + "geometry": [ + 0.8176307908857315, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "confidence": 1.0 + } + ] + } + ] + } + ] + } + ] + } +] ``` diff --git a/api/app/routes/detection.py b/api/app/routes/detection.py index 71c64a7c1c..e044d1f815 100644 --- a/api/app/routes/detection.py +++ b/api/app/routes/detection.py @@ -5,19 +5,31 @@ from typing import List -from fastapi import APIRouter, File, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import DetectionOut -from app.vision import det_predictor +from app.schemas import DetectionIn, DetectionOut +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor from doctr.file_utils import CLASS_NAME -from doctr.io import decode_img_as_tensor 
router = APIRouter() @router.post("/", response_model=List[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection") -async def text_detection(file: UploadFile = File(...)): +async def text_detection(request: DetectionIn = Depends(), files: List[UploadFile] = File(...)): """Runs docTR text detection model to analyze the input image""" - img = decode_img_as_tensor(file.file.read()) - boxes = det_predictor([img])[0] - return [DetectionOut(box=box.tolist()) for box in boxes[CLASS_NAME][:, :-1]] + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return [ + DetectionOut( + name=filename, + geometries=[ + geom[:-1].tolist() if len(geom) == 5 else resolve_geometry(geom.tolist()) for geom in doc[CLASS_NAME] + ], + ) + for doc, filename in zip(predictor(content), filenames) + ] diff --git a/api/app/routes/kie.py b/api/app/routes/kie.py index 2ef4cce4c8..46b2d92be1 100644 --- a/api/app/routes/kie.py +++ b/api/app/routes/kie.py @@ -3,27 +3,50 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details.
-from typing import Dict, List +from typing import List -from fastapi import APIRouter, File, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import OCROut -from app.vision import kie_predictor -from doctr.io import decode_img_as_tensor +from app.schemas import KIEElement, KIEIn, KIEOut +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor router = APIRouter() -@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE") -async def perform_kie(file: UploadFile = File(...)): +@router.post("/", response_model=List[KIEOut], status_code=status.HTTP_200_OK, summary="Perform KIE") +async def perform_kie(request: KIEIn = Depends(), files: List[UploadFile] = File(...)): """Runs docTR KIE model to analyze the input image""" - img = decode_img_as_tensor(file.file.read()) - out = kie_predictor([img]) - - return { - class_name: [ - OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value) - for prediction in out.pages[0].predictions[class_name] - ] - for class_name in out.pages[0].predictions.keys() - } + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + out = predictor(content) + + results = [ + KIEOut( + name=filenames[i], + orientation=page.orientation, + language=page.language, + dimensions=page.dimensions, + predictions=[ + KIEElement( + class_name=class_name, + items=[ + dict( + value=prediction.value, + geometry=resolve_geometry(prediction.geometry), + confidence=round(prediction.confidence, 2), + ) + for prediction in page.predictions[class_name] + ], + ) + for class_name in page.predictions.keys() + ], + ) + for i, page in enumerate(out.pages) + ] + + return results diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py index 37bb05e85a..4c766e9f35
100644 --- a/api/app/routes/ocr.py +++ b/api/app/routes/ocr.py @@ -5,24 +5,59 @@ from typing import List -from fastapi import APIRouter, File, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import OCROut -from app.vision import predictor -from doctr.io import decode_img_as_tensor +from app.schemas import OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor router = APIRouter() @router.post("/", response_model=List[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR") -async def perform_ocr(file: UploadFile = File(...)): +async def perform_ocr(request: OCRIn = Depends(), files: List[UploadFile] = File(...)): """Runs docTR OCR model to analyze the input image""" - img = decode_img_as_tensor(file.file.read()) - out = predictor([img]) - - return [ - OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value) - for block in out.pages[0].blocks - for line in block.lines - for word in line.words + try: + # decode the uploads, then build the predictor; a ValueError from either becomes a 400 + content, filenames = await get_documents(files) + predictor = init_predictor(request) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + out = predictor(content) + + results = [ + OCROut( + name=filenames[i], + orientation=page.orientation, + language=page.language, + dimensions=page.dimensions, + items=[ + OCRPage( + blocks=[ + OCRBlock( + geometry=resolve_geometry(block.geometry), + lines=[ + OCRLine( + geometry=resolve_geometry(line.geometry), + words=[ + OCRWord( + value=word.value, + geometry=resolve_geometry(word.geometry), + confidence=round(word.confidence, 2), + ) + for word in line.words + ], + ) + for line in block.lines + ], + ) + for block in page.blocks + ] + ) + ], + ) + for i, page in enumerate(out.pages) ] + + return results diff --git a/api/app/routes/recognition.py b/api/app/routes/recognition.py index
9727424995..65de3e07ba 100644 --- a/api/app/routes/recognition.py +++ b/api/app/routes/recognition.py @@ -3,18 +3,28 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from fastapi import APIRouter, File, UploadFile, status +from typing import List -from app.schemas import RecognitionOut -from app.vision import reco_predictor -from doctr.io import decode_img_as_tensor +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status + +from app.schemas import RecognitionIn, RecognitionOut +from app.utils import get_documents +from app.vision import init_predictor router = APIRouter() -@router.post("/", response_model=RecognitionOut, status_code=status.HTTP_200_OK, summary="Perform text recognition") -async def text_recognition(file: UploadFile = File(...)): +@router.post( + "/", response_model=List[RecognitionOut], status_code=status.HTTP_200_OK, summary="Perform text recognition" +) +async def text_recognition(request: RecognitionIn = Depends(), files: List[UploadFile] = File(...)): """Runs docTR text recognition model to analyze the input image""" - img = decode_img_as_tensor(file.file.read()) - out = reco_predictor([img]) - return RecognitionOut(value=out[0][0]) + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return [ + RecognitionOut(name=filename, value=res[0], confidence=round(res[1], 2)) + for res, filename in zip(predictor(content), filenames) + ] diff --git a/api/app/schemas.py b/api/app/schemas.py index a5bef9cef8..8fe3fce38f 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -3,19 +3,132 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details.
-from typing import Tuple +from typing import Dict, List, Tuple, Union from pydantic import BaseModel, Field -# Recognition output +class KIEIn(BaseModel): + det_arch: str = Field(default="db_resnet50", examples=["db_resnet50"]) + reco_arch: str = Field(default="crnn_vgg16_bn", examples=["crnn_vgg16_bn"]) + assume_straight_pages: bool = Field(default=True, examples=[True]) + preserve_aspect_ratio: bool = Field(default=True, examples=[True]) + detect_orientation: bool = Field(default=False, examples=[False]) + detect_language: bool = Field(default=False, examples=[False]) + symmetric_pad: bool = Field(default=True, examples=[True]) + straighten_pages: bool = Field(default=False, examples=[False]) + det_bs: int = Field(default=2, examples=[2]) + reco_bs: int = Field(default=128, examples=[128]) + bin_thresh: float = Field(default=0.1, examples=[0.1]) + box_thresh: float = Field(default=0.1, examples=[0.1]) + + +class OCRIn(KIEIn, BaseModel): + resolve_lines: bool = Field(default=True, examples=[True]) + resolve_blocks: bool = Field(default=True, examples=[True]) + paragraph_break: float = Field(default=0.0035, examples=[0.0035]) + + +class RecognitionIn(BaseModel): + reco_arch: str = Field(default="crnn_vgg16_bn", examples=["crnn_vgg16_bn"]) + reco_bs: int = Field(default=128, examples=[128]) + + +class DetectionIn(BaseModel): + det_arch: str = Field(default="db_resnet50", examples=["db_resnet50"]) + assume_straight_pages: bool = Field(default=True, examples=[True]) + preserve_aspect_ratio: bool = Field(default=True, examples=[True]) + symmetric_pad: bool = Field(default=True, examples=[True]) + det_bs: int = Field(default=2, examples=[2]) + bin_thresh: float = Field(default=0.1, examples=[0.1]) + box_thresh: float = Field(default=0.1, examples=[0.1]) + + class RecognitionOut(BaseModel): - value: str = Field(..., example="Hello") + name: str = Field(..., examples=["example.jpg"]) + value: str = Field(..., examples=["Hello"]) + confidence: float = Field(..., 
examples=[0.99]) class DetectionOut(BaseModel): - box: Tuple[float, float, float, float] + name: str = Field(..., examples=["example.jpg"]) + geometries: List[List[float]] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + + +class OCRWord(BaseModel): + value: str = Field(..., examples=["example"]) + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + confidence: float = Field(..., examples=[0.99]) + + +class OCRLine(BaseModel): + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + words: List[OCRWord] = Field( + ..., examples=[{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}] + ) + + +class OCRBlock(BaseModel): + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + lines: List[OCRLine] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + ) + + +class OCRPage(BaseModel): + blocks: List[OCRBlock] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "lines": [ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + } + ], + ) + + +class OCROut(BaseModel): + name: str = Field(..., examples=["example.jpg"]) + orientation: Dict[str, Union[float, None]] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) + language: Dict[str, Union[str, float, None]] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) + dimensions: Tuple[int, int] = Field(..., examples=[(100, 100)]) + items: List[OCRPage] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "lines": [ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + } + ], + ) + + +class KIEElement(BaseModel): + class_name: str = Field(..., examples=["example"]) + items: List[Dict[str, Union[str, List[float], float]]] = 
Field( + ..., examples=[{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}] + ) -class OCROut(RecognitionOut, DetectionOut): - pass +class KIEOut(BaseModel): + name: str = Field(..., examples=["example.jpg"]) + orientation: Dict[str, Union[float, None]] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) + language: Dict[str, Union[str, float, None]] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) + dimensions: Tuple[int, int] = Field(..., examples=[(100, 100)]) + predictions: List[KIEElement] diff --git a/api/app/utils.py b/api/app/utils.py new file mode 100644 index 0000000000..511a75ad9e --- /dev/null +++ b/api/app/utils.py @@ -0,0 +1,49 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, List, Tuple, Union + +import numpy as np +from fastapi import UploadFile + +from doctr.io import DocumentFile + + +def resolve_geometry( + geom: Any, +) -> Union[Tuple[float, float, float, float], Tuple[float, float, float, float, float, float, float, float]]: + if len(geom) == 4: + return (*geom[0], *geom[1], *geom[2], *geom[3]) + return (*geom[0], *geom[1]) + + +async def get_documents(files: List[UploadFile]) -> Tuple[List[np.ndarray], List[str]]: # pragma: no cover + """Convert a list of UploadFile objects to lists of numpy arrays and their corresponding filenames + + Args: + ---- + files: list of UploadFile objects + + Returns: + ------- + Tuple[List[np.ndarray], List[str]]: list of numpy arrays and their corresponding filenames + + """ + filenames = [] + docs = [] + for file in files: + mime_type = file.content_type + if mime_type in ["image/jpeg", "image/png"]: + docs.extend(DocumentFile.from_images([await file.read()])) + filenames.append(file.filename or "") + elif mime_type == "application/pdf": + pdf_content = DocumentFile.from_pdf(await file.read()) + docs.extend(pdf_content) + 
filenames.extend([file.filename or ""] * len(pdf_content)) + else: + raise ValueError(f"Unsupported file format: {mime_type} for file {file.filename}") + + return docs, filenames diff --git a/api/app/vision.py b/api/app/vision.py index c3e5f7560a..005c8d1548 100644 --- a/api/app/vision.py +++ b/api/app/vision.py @@ -3,15 +3,45 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. + import tensorflow as tf gpu_devices = tf.config.experimental.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) +from typing import Callable, Union + from doctr.models import kie_predictor, ocr_predictor -predictor = ocr_predictor(pretrained=True) -det_predictor = predictor.det_predictor -reco_predictor = predictor.reco_predictor -kie_predictor = kie_predictor(pretrained=True) +from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn + + +def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) -> Callable: + """Initialize the predictor based on the request + + Args: + ---- + request: input request + + Returns: + ------- + Callable: the predictor + """ + params = request.model_dump() + bin_thresh = params.pop("bin_thresh", None) + box_thresh = params.pop("box_thresh", None) + if isinstance(request, (OCRIn, RecognitionIn, DetectionIn)): + predictor = ocr_predictor(pretrained=True, **params) + predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh + if isinstance(request, DetectionIn): + return predictor.det_predictor + elif isinstance(request, RecognitionIn): + return predictor.reco_predictor + return predictor + elif isinstance(request, KIEIn): + predictor = kie_predictor(pretrained=True, **params) + predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh + return
predictor diff --git a/api/docker-compose.yml b/api/docker-compose.yml index cc85ef841b..4140ed9cbb 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,4 +1,4 @@ -version: '3.7' +version: '3.8' services: web: diff --git a/api/pyproject.toml b/api/pyproject.toml index cb76a5c648..9824f0442a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -11,7 +11,6 @@ license = "Apache-2.0" [tool.poetry.dependencies] python = ">=3.9,<3.12" -tensorflow = ">=2.11.0,<2.16.0" # cf. https://github.com/mindee/doctr/pull/1461 python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['tf'], branch = "main" } # Fastapi: minimum version required to avoid pydantic error # cf. https://github.com/tiangolo/fastapi/issues/4168 diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 5fb7340c18..41872b47ec 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -17,8 +17,234 @@ def mock_detection_image(tmpdir_factory): return requests.get(url).content +@pytest_asyncio.fixture(scope="session") +def mock_txt_file(tmpdir_factory): + txt_file = tmpdir_factory.mktemp("data").join("mock.txt") + txt_file.write("mock text") + return txt_file.read("rb") + + @pytest_asyncio.fixture(scope="function") async def test_app_asyncio(): # for httpx>=20, follow_redirects=True (cf. 
https://github.com/encode/httpx/releases/tag/0.20.0) async with AsyncClient(app=app, base_url="http://test", follow_redirects=True) as ac: yield ac # testing happens here + + +@pytest_asyncio.fixture(scope="function") +def mock_detection_response(): + return { + "box": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "geometries": [ + [0.724609375, 0.1787109375, 0.7900390625, 0.2080078125], + [0.6748046875, 0.1796875, 0.7314453125, 0.20703125], + ], + }, + "poly": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "geometries": [ + [ + 0.7873152494430542, + 0.17740710079669952, + 0.7884310483932495, + 0.20474515855312347, + 0.7244035005569458, + 0.20735852420330048, + 0.7232877016067505, + 0.18002046644687653, + ], + [ + 0.7286394834518433, + 0.17740298807621002, + 0.7298480272293091, + 0.2027825564146042, + 0.6746810674667358, + 0.20540954172611237, + 0.67347252368927, + 0.1800299733877182, + ], + ], + }, + } + + +@pytest_asyncio.fixture(scope="function") +def mock_kie_response(): + return { + "box": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": {"value": None, "confidence": None}, + "language": {"value": None, "confidence": None}, + "dimensions": [2339, 1654], + "predictions": [ + { + "class_name": "words", + "items": [ + { + "value": "Hello", + "geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125], + "confidence": 1, + }, + { + "value": "world!", + "geometry": [0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125], + "confidence": 1, + }, + ], + } + ], + }, + "poly": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": {"value": None, "confidence": None}, + "language": {"value": None, "confidence": None}, + "dimensions": [2339, 1654], + "predictions": [ + { + "class_name": "words", + "items": [ + { + "value": "Hello", + "geometry": [ + 0.7453157305717468, + 0.1800299733877182, + 0.8233299851417542, + 0.17740298807621002, + 
0.8250390291213989, + 0.2027825564146042, + 0.7470247745513916, + 0.20540954172611237, + ], + "confidence": 0.99, + }, + { + "value": "world!", + "geometry": [ + 0.8157618045806885, + 0.18002046644687653, + 0.9063061475753784, + 0.17740710079669952, + 0.9078840017318726, + 0.20474515855312347, + 0.8173396587371826, + 0.20735852420330048, + ], + "confidence": 1, + }, + ], + } + ], + }, + } + + +@pytest_asyncio.fixture(scope="function") +def mock_ocr_response(): + return { + "box": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": {"value": None, "confidence": None}, + "language": {"value": None, "confidence": None}, + "dimensions": [2339, 1654], + "items": [ + { + "blocks": [ + { + "geometry": [0.7471996155154171, 0.1787109375, 0.9101580212741838, 0.2080078125], + "lines": [ + { + "geometry": [0.7471996155154171, 0.1787109375, 0.9101580212741838, 0.2080078125], + "words": [ + { + "value": "Hello", + "geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125], + "confidence": 1, + }, + { + "value": "world!", + "geometry": [ + 0.8176307908857315, + 0.1787109375, + 0.9101580212741838, + 0.2080078125, + ], + "confidence": 1, + }, + ], + } + ], + } + ] + } + ], + }, + "poly": { + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": {"value": None, "confidence": None}, + "language": {"value": None, "confidence": None}, + "dimensions": [2339, 1654], + "items": [ + { + "blocks": [ + { + "geometry": [ + 0.7451040148735046, + 0.17927837371826172, + 0.9062581658363342, + 0.17407986521720886, + 0.9072266221046448, + 0.2041015625, + 0.7460724711418152, + 0.20930007100105286, + ], + "lines": [ + { + "geometry": [ + 0.7451040148735046, + 0.17927837371826172, + 0.9062581658363342, + 0.17407986521720886, + 0.9072266221046448, + 0.2041015625, + 0.7460724711418152, + 0.20930007100105286, + ], + "words": [ + { + "value": "Hello", + "geometry": [ + 0.7453157305717468, + 0.1800299733877182, + 0.8233299851417542, + 
0.17740298807621002, + 0.8250390291213989, + 0.2027825564146042, + 0.7470247745513916, + 0.20540954172611237, + ], + "confidence": 0.99, + }, + { + "value": "world!", + "geometry": [ + 0.8157618045806885, + 0.18002046644687653, + 0.9063061475753784, + 0.17740710079669952, + 0.9078840017318726, + 0.20474515855312347, + 0.8173396587371826, + 0.20735852420330048, + ], + "confidence": 1, + }, + ], + } + ], + } + ] + } + ], + }, + } diff --git a/api/tests/routes/test_detection.py b/api/tests/routes/test_detection.py index db3c17c5e7..51672fd962 100644 --- a/api/tests/routes/test_detection.py +++ b/api/tests/routes/test_detection.py @@ -1,24 +1,58 @@ import numpy as np import pytest -from scipy.optimize import linear_sum_assignment -from doctr.utils.metrics import box_iou + +def common_test(json_response, expected_response): + assert isinstance(json_response, list) and len(json_response) == 2 + first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice + + assert isinstance(first_pred["name"], str) + np.testing.assert_allclose(first_pred["geometries"], expected_response["geometries"], rtol=1e-2) @pytest.mark.asyncio -async def test_text_detection(test_app_asyncio, mock_detection_image): - response = await test_app_asyncio.post("/detection", files={"file": mock_detection_image}) +async def test_text_detection_box(test_app_asyncio, mock_detection_image, mock_detection_response): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50"} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/detection", params=params, files=files, headers=headers) assert response.status_code == 200 json_response = response.json() - gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) - gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 - gt_boxes[:, 
[1, 3]] = gt_boxes[:, [1, 3]] / 2339 - - # Check that IoU with GT if reasonable - assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0] - pred_boxes = np.array([elt["box"] for elt in json_response]) - iou_mat = box_iou(gt_boxes, pred_boxes) - gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) - is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 - assert gt_idxs[is_kept].shape[0] == gt_boxes.shape[0] + expected_box_response = mock_detection_response["box"] + common_test(json_response, expected_box_response) + + +@pytest.mark.asyncio +async def test_text_detection_poly(test_app_asyncio, mock_detection_image, mock_detection_response): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "assume_straight_pages": False} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/detection", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + expected_poly_response = mock_detection_response["poly"] + common_test(json_response, expected_poly_response) + + +@pytest.mark.asyncio +async def test_text_detection_invalid_file(test_app_asyncio, mock_txt_file): + headers = { + "accept": "application/json", + } + files = [ + ("files", ("test.txt", mock_txt_file)), + ] + response = await test_app_asyncio.post("/detection", files=files, headers=headers) + assert response.status_code == 400 diff --git a/api/tests/routes/test_kie.py b/api/tests/routes/test_kie.py index cf3c5678a5..36ca4b5b62 100644 --- a/api/tests/routes/test_kie.py +++ b/api/tests/routes/test_kie.py @@ -1,28 +1,74 @@ import numpy as np import pytest -from scipy.optimize import linear_sum_assignment -from doctr.utils.metrics import box_iou + +def common_test(json_response, expected_response): + first_pred = json_response[0] # it's enough to test for the first file 
because the same image is used twice + assert isinstance(first_pred["name"], str) + assert ( + isinstance(first_pred["dimensions"], (tuple, list)) + and len(first_pred["dimensions"]) == 2 + and all(isinstance(dim, int) for dim in first_pred["dimensions"]) + ) + assert isinstance(first_pred["predictions"], list) + assert isinstance(expected_response["predictions"], list) + + for pred, expected_pred in zip(first_pred["predictions"], expected_response["predictions"]): + assert pred["class_name"] == expected_pred["class_name"] + assert isinstance(pred["items"], list) + assert isinstance(expected_pred["items"], list) + + for pred_item, expected_pred_item in zip(pred["items"], expected_pred["items"]): + assert isinstance(pred_item["value"], str) and pred_item["value"] == expected_pred_item["value"] + assert isinstance(pred_item["confidence"], (int, float)) + np.testing.assert_allclose(pred_item["geometry"], expected_pred_item["geometry"], rtol=1e-2) + + +@pytest.mark.asyncio +async def test_kie_box(test_app_asyncio, mock_detection_image, mock_kie_response): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + expected_box_response = mock_kie_response["box"] + assert isinstance(json_response, list) and len(json_response) == 2 + common_test(json_response, expected_box_response) @pytest.mark.asyncio -async def test_perform_kie(test_app_asyncio, mock_detection_image): - response = await test_app_asyncio.post("/kie", files={"file": mock_detection_image}) +async def test_kie_poly(test_app_asyncio, mock_detection_image, mock_kie_response): + headers = { + "accept": "application/json", + } + params = 
{"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "assume_straight_pages": False} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers) assert response.status_code == 200 json_response = response.json() - gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) - gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 - gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339 - gt_labels = ["Hello", "world!"] - - # Check that IoU with GT if reasonable - assert isinstance(json_response, dict) and len(list(json_response.values())[0]) == gt_boxes.shape[0] - pred_boxes = np.array([elt["box"] for json_out in json_response.values() for elt in json_out]) - pred_labels = np.array([elt["value"] for json_out in json_response.values() for elt in json_out]) - iou_mat = box_iou(gt_boxes, pred_boxes) - gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) - is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 - gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept] - assert gt_idxs.shape[0] == gt_boxes.shape[0] - assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs)) + expected_poly_response = mock_kie_response["poly"] + assert isinstance(json_response, list) and len(json_response) == 2 + common_test(json_response, expected_poly_response) + + +@pytest.mark.asyncio +async def test_kie_invalid_file(test_app_asyncio, mock_txt_file): + headers = { + "accept": "application/json", + } + files = [ + ("files", ("test.txt", mock_txt_file)), + ] + response = await test_app_asyncio.post("/kie", files=files, headers=headers) + assert response.status_code == 400 diff --git a/api/tests/routes/test_ocr.py b/api/tests/routes/test_ocr.py index 3d7b3df3b9..c702084447 100644 --- a/api/tests/routes/test_ocr.py +++ b/api/tests/routes/test_ocr.py @@ -1,28 
+1,72 @@ import numpy as np import pytest -from scipy.optimize import linear_sum_assignment -from doctr.utils.metrics import box_iou + +def common_test(json_response, expected_response): + first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice + + assert isinstance(first_pred["name"], str) + assert ( + isinstance(first_pred["dimensions"], (tuple, list)) + and len(first_pred["dimensions"]) == 2 + and all(isinstance(dim, int) for dim in first_pred["dimensions"]) + ) + for item, expected_item in zip(first_pred["items"], expected_response["items"]): + for block, expected_block in zip(item["blocks"], expected_item["blocks"]): + np.testing.assert_allclose(block["geometry"], expected_block["geometry"], rtol=1e-2) + for line, expected_line in zip(block["lines"], expected_block["lines"]): + np.testing.assert_allclose(line["geometry"], expected_line["geometry"], rtol=1e-2) + for word, expected_word in zip(line["words"], expected_line["words"]): + np.testing.assert_allclose(word["geometry"], expected_word["geometry"], rtol=1e-2) + assert isinstance(word["value"], str) and word["value"] == expected_word["value"] + assert isinstance(word["confidence"], (int, float)) @pytest.mark.asyncio -async def test_perform_ocr(test_app_asyncio, mock_detection_image): - response = await test_app_asyncio.post("/ocr", files={"file": mock_detection_image}) +async def test_ocr_box(test_app_asyncio, mock_detection_image, mock_ocr_response): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers) assert response.status_code == 200 json_response = response.json() - gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) 
- gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 - gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339 - gt_labels = ["Hello", "world!"] - - # Check that IoU with GT if reasonable - assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0] - pred_boxes = np.array([elt["box"] for elt in json_response]) - pred_labels = np.array([elt["value"] for elt in json_response]) - iou_mat = box_iou(gt_boxes, pred_boxes) - gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) - is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 - gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept] - assert gt_idxs.shape[0] == gt_boxes.shape[0] - assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs)) + expected_box_response = mock_ocr_response["box"] + assert isinstance(json_response, list) and len(json_response) == 2 + common_test(json_response, expected_box_response) + + +@pytest.mark.asyncio +async def test_ocr_poly(test_app_asyncio, mock_detection_image, mock_ocr_response): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "assume_straight_pages": False} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + expected_poly_response = mock_ocr_response["poly"] + assert isinstance(json_response, list) and len(json_response) == 2 + common_test(json_response, expected_poly_response) + + +@pytest.mark.asyncio +async def test_ocr_invalid_file(test_app_asyncio, mock_txt_file): + headers = { + "accept": "application/json", + } + files = [ + ("files", ("test.txt", mock_txt_file)), + ] + response = await test_app_asyncio.post("/ocr", files=files, headers=headers) + assert response.status_code == 400 diff --git 
a/api/tests/routes/test_recognition.py b/api/tests/routes/test_recognition.py index 95467758a8..61c6561133 100644 --- a/api/tests/routes/test_recognition.py +++ b/api/tests/routes/test_recognition.py @@ -2,8 +2,29 @@ @pytest.mark.asyncio -async def test_text_recognition(test_app_asyncio, mock_recognition_image): - response = await test_app_asyncio.post("/recognition", files={"file": mock_recognition_image}) +async def test_text_recognition(test_app_asyncio, mock_recognition_image, mock_txt_file): + headers = { + "accept": "application/json", + } + params = {"reco_arch": "crnn_vgg16_bn"} + files = [ + ("files", ("test.jpg", mock_recognition_image, "image/jpeg")), + ("files", ("test2.jpg", mock_recognition_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/recognition", params=params, files=files, headers=headers) assert response.status_code == 200 + json_response = response.json() + assert isinstance(json_response, list) and len(json_response) == 2 + for item in json_response: + assert isinstance(item["name"], str) + assert isinstance(item["value"], str) and item["value"] == "invite" + assert isinstance(item["confidence"], (int, float)) and item["confidence"] >= 0.8 - assert response.json() == {"value": "invite"} + headers = { + "accept": "application/json", + } + files = [ + ("files", ("test.txt", mock_txt_file)), + ] + response = await test_app_asyncio.post("/recognition", files=files, headers=headers) + assert response.status_code == 400 diff --git a/api/tests/utils/test_utils.py b/api/tests/utils/test_utils.py new file mode 100644 index 0000000000..09b3a2eb7a --- /dev/null +++ b/api/tests/utils/test_utils.py @@ -0,0 +1,9 @@ +from app.utils import resolve_geometry + + +def test_resolve_geometry(): + dummy_box = [(0.0, 0.0), (1.0, 0.0)] + dummy_polygon = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] + + assert resolve_geometry(dummy_box) == (0.0, 0.0, 1.0, 0.0) + assert resolve_geometry(dummy_polygon) == (0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 
1.0) diff --git a/api/tests/utils/test_vision.py b/api/tests/utils/test_vision.py new file mode 100644 index 0000000000..4375322a65 --- /dev/null +++ b/api/tests/utils/test_vision.py @@ -0,0 +1,13 @@ +from app.schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn +from app.vision import init_predictor +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.kie_predictor import KIEPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.recognition.predictor import RecognitionPredictor + + +def test_vision(): + assert isinstance(init_predictor(OCRIn()), OCRPredictor) + assert isinstance(init_predictor(DetectionIn()), DetectionPredictor) + assert isinstance(init_predictor(RecognitionIn()), RecognitionPredictor) + assert isinstance(init_predictor(KIEIn()), KIEPredictor) diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py index 71a09abd85..424f59563d 100644 --- a/doctr/datasets/generator/base.py +++ b/doctr/datasets/generator/base.py @@ -20,7 +20,7 @@ def synthesize_text_img( font_family: Optional[str] = None, background_color: Optional[Tuple[int, int, int]] = None, text_color: Optional[Tuple[int, int, int]] = None, -) -> Image: +) -> Image.Image: """Generate a synthetic text image Args: @@ -81,7 +81,7 @@ def __init__( self._data: List[Image.Image] = [] if cache_samples: self._data = [ - (synthesize_text_img(char, font_family=font), idx) + (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc] for idx, char in enumerate(self.vocab) for font in self.font_family ] @@ -93,7 +93,7 @@ def _read_sample(self, index: int) -> Tuple[Any, int]: # Samples are already cached if len(self._data) > 0: idx = index % len(self._data) - pil_img, target = self._data[idx] + pil_img, target = self._data[idx] # type: ignore[misc] else: target = index % len(self.vocab) pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family)) @@ -132,7 +132,8 @@ def __init__( if 
cache_samples: _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)] self._data = [ - (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) for text in _words + (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc] + for text in _words ] def _generate_string(self, min_chars: int, max_chars: int) -> str: @@ -145,7 +146,7 @@ def __len__(self) -> int: def _read_sample(self, index: int) -> Tuple[Any, str]: # Samples are already cached if len(self._data) > 0: - pil_img, target = self._data[index] + pil_img, target = self._data[index] # type: ignore[misc] else: target = self._generate_string(*self.wordlen_range) pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family)) diff --git a/doctr/io/image/pytorch.py b/doctr/io/image/pytorch.py index 26e3e76f95..2e8450e840 100644 --- a/doctr/io/image/pytorch.py +++ b/doctr/io/image/pytorch.py @@ -16,7 +16,7 @@ __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] -def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: +def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Convert a PIL Image to a PyTorch tensor Args: diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index dbfc55b4be..28fb2fadd5 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -15,7 +15,7 @@ __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] -def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: +def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: """Convert a PIL Image to a TensorFlow tensor Args: