From 8e95e2fb4bb708eb0b90ce887e2306de3e057b8b Mon Sep 17 00:00:00 2001 From: T2K-Felix <125863421+felixT2K@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:46:10 +0200 Subject: [PATCH 01/18] [misc] post release 0.9.1 (#1689) --- .conda/meta.yaml | 2 +- api/pyproject.toml | 2 +- docs/build.sh | 3 ++- docs/source/_static/js/custom.js | 5 +++-- docs/source/changelog.rst | 4 ++++ setup.py | 2 +- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 042cc9230b..fcac492132 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -1,7 +1,7 @@ {% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %} {% set project = pyproject.get('project') %} {% set urls = pyproject.get('project', {}).get('urls') %} -{% set version = environ.get('BUILD_VERSION', '0.9.0a0') %} +{% set version = environ.get('BUILD_VERSION', '0.9.1a0') %} package: name: {{ project.get('name') }} diff --git a/api/pyproject.toml b/api/pyproject.toml index 9824f0442a..459c09b2ff 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "doctr-api" -version = "0.9.0a0" +version = "0.9.1a0" description = "Backend template for your OCR API with docTR" authors = ["Mindee "] license = "Apache-2.0" diff --git a/docs/build.sh b/docs/build.sh index 62c40ebbb8..1366b59c36 100644 --- a/docs/build.sh +++ b/docs/build.sh @@ -53,5 +53,6 @@ deploy_doc "9d03085" v0.5.1 deploy_doc "dcbb21f" v0.6.0 deploy_doc "75bddfc" v0.7.0 deploy_doc "67d1087" v0.8.0 -deploy_doc "62d94ff" # v0.8.1 Latest stable release +deploy_doc "62d94ff" v0.8.1 +deploy_doc "894eafd" # v0.9.0 Latest stable release rm -rf _build _static _conf.py diff --git a/docs/source/_static/js/custom.js b/docs/source/_static/js/custom.js index 50a0653c5e..2b96d1a7f6 100644 --- a/docs/source/_static/js/custom.js +++ b/docs/source/_static/js/custom.js @@ -3,11 +3,12 @@ // These two things need to be updated at each release for the version selector. // Last stable version -const stableVersion = "v0.8.1" +const stableVersion = "v0.9.0" // Dictionary doc folder to label. The last stable version should have an empty key. 
const versionMapping = { "latest": "latest", - "": "v0.8.1 (stable)", + "": "v0.9.0 (stable)", + "v0.8.1": "v0.8.1", "v0.8.0": "v0.8.0", "v0.7.0": "v0.7.0", "v0.6.0": "v0.6.0", diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 5a5ffc3f74..8ae242b712 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,10 @@ Changelog ========= +v0.9.0 (2024-08-08) +------------------- +Release note: `v0.9.0 `_ + v0.8.1 (2024-03-04) ------------------- Release note: `v0.8.1 `_ diff --git a/setup.py b/setup.py index 2347b4ddd6..f45f3f157d 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ from setuptools import setup PKG_NAME = "python-doctr" -VERSION = os.getenv("BUILD_VERSION", "0.9.0a0") +VERSION = os.getenv("BUILD_VERSION", "0.9.1a0") if __name__ == "__main__": From d7f453329f583798c0d2774f343e29ab24450a4d Mon Sep 17 00:00:00 2001 From: MinhChien <45474685+MinhChien9@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:21:42 +0700 Subject: [PATCH 02/18] [Datasets] Correct Vietnamese letters (#1693) --- docs/source/modules/datasets.rst | 4 ++-- doctr/datasets/vocabs.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/modules/datasets.rst b/docs/source/modules/datasets.rst index 6f6f2a530d..b4690247c1 100644 --- a/docs/source/modules/datasets.rst +++ b/docs/source/modules/datasets.rst @@ -152,8 +152,8 @@ of vocabs. - 106 - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿åäöÅÄÖ * - vietnamese - - 234 - - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ + - 236 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿áàảạãăắằẳẵặâấầẩẫậđéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬĐÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ * - hebrew - 123 - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿אבגדהוזחטיכלמנסעפצקרשת₪ diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py index f682560d1d..c17a6a5f0f 100644 --- a/doctr/datasets/vocabs.py +++ b/doctr/datasets/vocabs.py @@ -53,8 +53,8 @@ VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ" VOCABS["vietnamese"] = ( VOCABS["english"] - + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ" - + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ" + + "áàảạãăắằẳẵặâấầẩẫậđéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ" + + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬĐÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ" ) VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"] From 06bce5113beb760c947d3876ed41e1aa1972d869 Mon Sep 17 00:00:00 2001 From: Koen Farell Date: Wed, 21 Aug 2024 18:55:54 +0300 Subject: [PATCH 03/18] feat: added ukrainian vocab (#1700) --- docs/source/modules/datasets.rst | 6 ++++++ doctr/datasets/vocabs.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/docs/source/modules/datasets.rst b/docs/source/modules/datasets.rst index b4690247c1..872212a121 100644 --- a/docs/source/modules/datasets.rst +++ b/docs/source/modules/datasets.rst @@ -94,6 +94,9 @@ of vocabs. 
* - arabic_letters - 37 - ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي + * - generic_cyrillic_letters + - 58 + - абвгдежзийклмнопрстуфхцчшщьюяАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЬЮЯ * - persian_letters - 5 - پچڢڤگ @@ -151,6 +154,9 @@ of vocabs. * - swedish - 106 - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿åäöÅÄÖ + * - ukrainian + - 115 + - абвгдежзийклмнопрстуфхцчшщьюяАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЬЮЯ0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿ґіїєҐІЇЄ₴ * - vietnamese - 236 - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿áàảạãăắằẳẵặâấầẩẫậđéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬĐÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py index c17a6a5f0f..91c5af7950 100644 --- a/doctr/datasets/vocabs.py +++ b/doctr/datasets/vocabs.py @@ -25,6 +25,7 @@ "hindi_punctuation": "।,?!:्ॐ॰॥॰", "bangla_letters": "অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃেৈোৌ্ৎংঃঁ", "bangla_digits": "০১২৩৪৫৬৭৮৯", + "generic_cyrillic_letters": "абвгдежзийклмнопрстуфхцчшщьюяАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЬЮЯ", } VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"] @@ -59,6 +60,7 @@ VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"] VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"] +VOCABS["ukrainian"] = VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴" VOCABS["multilingual"] = "".join( dict.fromkeys( VOCABS["french"] From 4434213bf3554b6ceb4bae104092ef06bfece855 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 26 Aug 2024 10:54:04 +0200 Subject: [PATCH 04/18] fix straighten pages (#1697) --- doctr/models/builder.py | 4 ++-- doctr/models/kie_predictor/pytorch.py | 3 +++ doctr/models/kie_predictor/tensorflow.py | 3 +++ doctr/models/modules/vision_transformer/pytorch.py | 2 +- doctr/models/modules/vision_transformer/tensorflow.py | 2 +- doctr/models/predictor/base.py | 4 ++-- doctr/models/predictor/pytorch.py | 3 +++ doctr/models/predictor/tensorflow.py | 3 +++ 8 files changed, 18 insertions(+), 6 deletions(-) diff --git a/doctr/models/builder.py b/doctr/models/builder.py index 4773404dec..8dfcafcc9d 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -266,7 +266,7 @@ def _build_blocks( Line([ Word( *word_preds[idx], - tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + tuple(tuple(pt) for pt in boxes[idx].tolist()), # type: ignore[arg-type] float(objectness_scores[idx]), crop_orientations[idx], ) @@ -500,7 +500,7 @@ def _build_blocks( # type: ignore[override] Prediction( value=word_preds[idx][0], confidence=word_preds[idx][1], - geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + geometry=tuple(tuple(pt) for pt in boxes[idx].tolist()), # type: ignore[arg-type] objectness_score=float(objectness_scores[idx]), crop_orientation=crop_orientations[idx], ) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 5c665d1800..4bcedc7064 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -99,6 +99,9 @@ def forward( origin_pages_orientations = None if self.straighten_pages: pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, 
origin_pages_orientations) # type: ignore + # update page shapes after straightening + origin_page_shapes = [page.shape[:2] for page in pages] + # Forward again to get predictions on straight pages loc_preds = self.det_predictor(pages, **kwargs) diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index 085f3aecbe..d9d765bbe6 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -99,6 +99,9 @@ def __call__( origin_pages_orientations = None if self.straighten_pages: pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) + # update page shapes after straightening + origin_page_shapes = [page.shape[:2] for page in pages] + # Forward again to get predictions on straight pages loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment] diff --git a/doctr/models/modules/vision_transformer/pytorch.py b/doctr/models/modules/vision_transformer/pytorch.py index 4ff07ed4ff..c13edf234b 100644 --- a/doctr/models/modules/vision_transformer/pytorch.py +++ b/doctr/models/modules/vision_transformer/pytorch.py @@ -20,7 +20,7 @@ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size channels, height, width = input_shape self.patch_size = patch_size self.interpolate = True if patch_size[0] == patch_size[1] else False - self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) + self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size)) self.num_patches = self.grid_size[0] * self.grid_size[1] self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py index a78f0da3fb..8386172eb1 100644 --- a/doctr/models/modules/vision_transformer/tensorflow.py +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -22,7 +22,7 @@ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size height, width, _ = input_shape self.patch_size = patch_size self.interpolate = True if patch_size[0] == patch_size[1] else False - self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) + self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size)) self.num_patches = self.grid_size[0] * self.grid_size[1] self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token") diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 965abd8143..0469b32ea3 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -101,8 +101,8 @@ def _straighten_pages( ] ) return [ - # We exapnd if the page is wider than tall and the angle is 90 or -90 - rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90) + # expand if height and width are not equal + rotate_image(page, angle, expand=page.shape[0] != page.shape[1]) for page, angle in zip(pages, origin_pages_orientations) ] diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 4b66365918..7cbf383a06 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -97,6 +97,9 @@ def forward( origin_pages_orientations = None if self.straighten_pages: pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore + # update page shapes after straightening + 
origin_page_shapes = [page.shape[:2] for page in pages] + # Forward again to get predictions on straight pages loc_preds = self.det_predictor(pages, **kwargs) diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 7c895d5e2a..f736614879 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -97,6 +97,9 @@ def __call__( origin_pages_orientations = None if self.straighten_pages: pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) + # update page shapes after straightening + origin_page_shapes = [page.shape[:2] for page in pages] + # forward again to get predictions on straight pages loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment] From 9045dcfc9c5c837b06fcda8e802f7cf1d95bd18c Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Thu, 29 Aug 2024 04:47:33 +0200 Subject: [PATCH 05/18] [orientation] Enable usage of custom trained orientation models (#1708) --- .../using_doctr/custom_models_training.rst | 73 ++++++++++++++++++- doctr/datasets/vocabs.py | 4 +- .../classification/mobilenet/pytorch.py | 2 + doctr/models/classification/zoo.py | 26 ++++--- doctr/models/factory/hub.py | 4 +- .../pytorch/test_models_classification_pt.py | 18 +++++ tests/pytorch/test_models_zoo_pt.py | 38 ++++++++++ .../test_models_classification_tf.py | 18 +++++ tests/tensorflow/test_models_zoo_tf.py | 38 ++++++++++ 9 files changed, 207 insertions(+), 14 deletions(-) diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index 6214dae2dc..ecf88d8116 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -1,7 +1,7 @@ Train your own model ==================== -If the pretrained models don't meet your specific needs, you have the option to train your own model using the doctr library. +If the pretrained models don't meet your specific needs, you have the option to train your own model using the docTR library. For details on the training process and the necessary data and data format, refer to the following links: - `detection `_ @@ -203,3 +203,74 @@ Load a model with customized Preprocessor: ) predictor = OCRPredictor(det_predictor, reco_predictor) + +Custom orientation classification models +---------------------------------------- + +If you work with rotated documents and make use of the orientation classification feature by passing one of the following arguments: + +* `assume_straight_pages=False` +* `detect_orientation=True` +* `straigten_pages=True` + +You can train your own orientation classification model using the docTR library. For details on the training process and the necessary data and data format, refer to the following link: + +- `orientation `_ + +**NOTE**: Currently we support only `mobilenet_v3_small` models for crop and page orientation classification. + +Loading your custom trained orientation classification model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. tabs:: + + .. tab:: TensorFlow + + .. 
code:: python3 + + from doctr.io import DocumentFile + from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation + from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor + + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) + custom_page_orientation_model.load_weights("/weights") + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) + custom_crop_orientation_model.load_weights("/weights") + + predictor = ocr_predictor( + pretrained=True, + assume_straight_pages=False, + straighten_pages=True, + detect_orientation=True, + ) + + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + .. tab:: PyTorch + + .. code:: python3 + + import torch + from doctr.io import DocumentFile + from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation + from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor + + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) + page_params = torch.load('', map_location="cpu") + custom_page_orientation_model.load_state_dict(page_params) + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) + crop_params = torch.load('', map_location="cpu") + custom_crop_orientation_model.load_state_dict(crop_params) + + predictor = ocr_predictor( + pretrained=True, + assume_straight_pages=False, + straighten_pages=True, + detect_orientation=True, + ) + + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py index 91c5af7950..94942d58e3 100644 --- a/doctr/datasets/vocabs.py +++ b/doctr/datasets/vocabs.py @@ -60,7 +60,9 @@ VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"] VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"] -VOCABS["ukrainian"] = VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴" +VOCABS["ukrainian"] = ( + VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴" +) VOCABS["multilingual"] = "".join( dict.fromkeys( VOCABS["french"] diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py index 615664854d..18470fdf11 100644 --- a/doctr/models/classification/mobilenet/pytorch.py +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -9,12 +9,14 @@ from typing import Any, Dict, List, Optional from torchvision.models import mobilenetv3 +from torchvision.models.mobilenetv3 import MobileNetV3 from doctr.datasets import VOCABS from ...utils import load_pretrained_params __all__ = [ + "MobileNetV3", "mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large", diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 9368bb225d..fccd5b5979 100644 --- 
a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -34,15 +34,21 @@ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] -def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor: - if arch not in ORIENTATION_ARCHS: - raise ValueError(f"unknown architecture '{arch}'") +def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor: + if isinstance(arch, str): + if arch not in ORIENTATION_ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + # Load directly classifier from backbone + _model = classification.__dict__[arch](pretrained=pretrained) + else: + if not isinstance(arch, classification.MobileNetV3): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch - # Load directly classifier from backbone - _model = classification.__dict__[arch](pretrained=pretrained) kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) - kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] predictor = OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model @@ -51,7 +57,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient def crop_orientation_predictor( - arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any ) -> OrientationPredictor: """Crop orientation classification architecture. @@ -71,11 +77,11 @@ def crop_orientation_predictor( ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, **kwargs) + return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs) def page_orientation_predictor( - arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any ) -> OrientationPredictor: """Page orientation classification architecture. 
@@ -95,4 +101,4 @@ def page_orientation_predictor( ------- OrientationPredictor """ - return _orientation_predictor(arch, pretrained, **kwargs) + return _orientation_predictor(arch, pretrained, model_type="page", **kwargs) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index a6c3f89322..41cd91579a 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -33,7 +33,7 @@ AVAILABLE_ARCHS = { - "classification": models.classification.zoo.ARCHS, + "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS, "detection": models.detection.zoo.ARCHS, "recognition": models.recognition.zoo.ARCHS, } @@ -174,7 +174,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name) repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False) - repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True) + repo = Repository(local_dir=local_cache_dir, clone_from=repo_url) with repo.commit(commit_message): _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task) diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index d2dbe5087a..f35a1ac9de 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -134,6 +134,15 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.crop_orientation_predictor( + classification.mobilenet_v3_small_crop_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True)) + def test_page_orientation_model(mock_payslip): text_box_0 = cv2.imread(mock_payslip) @@ -147,6 +156,15 @@ def test_page_orientation_model(mock_payslip): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.page_orientation_predictor( + classification.mobilenet_v3_small_page_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True)) + @pytest.mark.parametrize( "arch_name, input_shape, output_size", diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 0cac9724ee..9be66edd7b 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -7,6 +7,8 @@ from doctr.io import Document, DocumentFile from doctr.io.elements import KIEDocument from doctr.models import detection, recognition +from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation +from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.detection.zoo 
import detection_predictor from doctr.models.kie_predictor import KIEPredictor @@ -85,6 +87,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa orientation = 0 assert out.pages[0].orientation["value"] == orientation + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) @@ -209,6 +229,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa orientation = 0 assert out.pages[0].orientation["value"] == orientation + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 8b2c720328..77eb8253ca 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -113,6 +113,15 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.crop_orientation_predictor( + classification.mobilenet_v3_small_crop_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True)) + def test_page_orientation_model(mock_payslip): text_box_0 = cv2.imread(mock_payslip) @@ -126,6 +135,15 @@ def test_page_orientation_model(mock_payslip): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 
180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test custom model loading + classifier = classification.page_orientation_predictor( + classification.mobilenet_v3_small_page_orientation(pretrained=True) + ) + assert isinstance(classifier, OrientationPredictor) + + with pytest.raises(ValueError): + _ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True)) + # temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue) # ref.: https://github.com/mindee/doctr/pull/1201 diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index f20cb21f5c..4b7e606563 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -6,6 +6,8 @@ from doctr.io import Document, DocumentFile from doctr.io.elements import KIEDocument from doctr.models import detection, recognition +from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation +from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.detection.zoo import detection_predictor from doctr.models.kie_predictor import KIEPredictor @@ -84,6 +86,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa language = "unknown" assert out.pages[0].language["value"] == language + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert out.pages[0].orientation["value"] == orientation + def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) @@ -207,6 +227,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa language = "unknown" assert out.pages[0].language["value"] == language + # Test with custom orientation models + custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True) + custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True) + + if assume_straight_pages: + if predictor.detect_orientation or predictor.straighten_pages: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + else: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + + out = predictor(doc) + orientation = 0 + assert 
out.pages[0].orientation["value"] == orientation + def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) From 420ab32501ca1ff7d76c5c80e37a3a8ad6a1f89c Mon Sep 17 00:00:00 2001 From: Milos Acimovic Date: Fri, 27 Sep 2024 12:20:42 +0200 Subject: [PATCH 06/18] [orientation] Allow disable of page and crop orientation (#1735) --- api/app/schemas.py | 2 + demo/app.py | 16 ++- demo/backend/pytorch.py | 6 + demo/backend/tensorflow.py | 6 + docs/source/using_doctr/using_models.rst | 34 +++++- .../classification/predictor/pytorch.py | 18 +-- .../classification/predictor/tensorflow.py | 14 ++- doctr/models/classification/zoo.py | 8 +- doctr/models/kie_predictor/base.py | 4 + doctr/models/kie_predictor/pytorch.py | 1 + doctr/models/kie_predictor/tensorflow.py | 6 +- doctr/models/predictor/base.py | 30 +++-- doctr/models/predictor/pytorch.py | 1 + doctr/models/predictor/tensorflow.py | 6 +- doctr/utils/geometry.py | 104 ++++++++++++++---- tests/common/test_utils_geometry.py | 9 +- .../pytorch/test_models_classification_pt.py | 23 +++- tests/pytorch/test_models_zoo_pt.py | 46 ++++++-- .../test_models_classification_tf.py | 22 +++- tests/tensorflow/test_models_zoo_tf.py | 46 ++++++-- 20 files changed, 328 insertions(+), 74 deletions(-) diff --git a/api/app/schemas.py b/api/app/schemas.py index 6f4085a294..b231a740f9 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -19,6 +19,8 @@ class KIEIn(BaseModel): straighten_pages: bool = Field(default=False, examples=[False]) det_bs: int = Field(default=2, examples=[2]) reco_bs: int = Field(default=128, examples=[128]) + disable_page_orientation: bool = Field(default=False, examples=[False]) + disable_crop_orientation: bool = Field(default=False, examples=[False]) bin_thresh: float = Field(default=0.1, examples=[0.1]) box_thresh: float = Field(default=0.1, examples=[0.1]) diff --git a/demo/app.py b/demo/app.py index a2368ec90e..60adba0fb8 100644 --- a/demo/app.py +++ b/demo/app.py @@ -72,6 +72,12 @@ def main(det_archs, reco_archs): st.sidebar.title("Parameters") assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True) st.sidebar.write("\n") + # Disable page orientation detection + disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False) + st.sidebar.write("\n") + # Disable crop orientation detection + disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False) + st.sidebar.write("\n") # Straighten pages straighten_pages = st.sidebar.checkbox("Straighten pages", value=False) st.sidebar.write("\n") @@ -89,7 +95,15 @@ def main(det_archs, reco_archs): else: with st.spinner("Loading model..."): predictor = load_predictor( - det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device + det_arch, + reco_arch, + assume_straight_pages, + straighten_pages, + disable_page_orientation, + disable_crop_orientation, + bin_thresh, + box_thresh, + forward_device, ) with st.spinner("Analyzing..."): diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index 9ce8532b2f..e3ced74d5f 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -37,6 +37,8 @@ def load_predictor( reco_arch: str, assume_straight_pages: bool, straighten_pages: bool, + disable_page_orientation: bool, + disable_crop_orientation: bool, bin_thresh: float, box_thresh: float, device: torch.device, @@ -49,6 +51,8 @@ def load_predictor( reco_arch: recognition architecture assume_straight_pages: whether to 
assume straight pages or not straighten_pages: whether to straighten rotated pages or not + disable_page_orientation: whether to disable page orientation or not + disable_crop_orientation: whether to disable crop orientation or not bin_thresh: binarization threshold for the segmentation map box_thresh: minimal objectness score to consider a box device: torch.device, the device to load the predictor on @@ -65,6 +69,8 @@ def load_predictor( straighten_pages=straighten_pages, export_as_straight_boxes=straighten_pages, detect_orientation=not assume_straight_pages, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, ).to(device) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index 980ae628d8..6ca9614159 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -36,6 +36,8 @@ def load_predictor( reco_arch: str, assume_straight_pages: bool, straighten_pages: bool, + disable_page_orientation: bool, + disable_crop_orientation: bool, bin_thresh: float, box_thresh: float, device: tf.device, @@ -48,6 +50,8 @@ def load_predictor( reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not + disable_page_orientation: whether to disable page orientation or not + disable_crop_orientation: whether to disable crop orientation or not bin_thresh: binarization threshold for the segmentation map box_thresh: threshold for the detection boxes device: tf.device, the device to load the predictor on @@ -65,6 +69,8 @@ def load_predictor( straighten_pages=straighten_pages, export_as_straight_boxes=straighten_pages, detect_orientation=not assume_straight_pages, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, ) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 0524169afa..e6e5006f2e 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -283,13 +283,16 @@ Those architectures involve one stage of text detection, and one stage of text r You can pass specific boolean arguments to the predictor: -* `assume_straight_pages` -* `preserve_aspect_ratio` -* `symmetric_pad` +* `assume_straight_pages`: if you work with straight documents only, it will fit straight bounding boxes to the text areas. +* `preserve_aspect_ratio`: if you want to preserve the aspect ratio of your documents while resizing before sending them to the model. +* `symmetric_pad`: if you choose to preserve the aspect ratio, it will pad the image symmetrically and not from the bottom-right. Those 3 are going straight to the detection predictor, as mentioned above (in the detection part). +Additional arguments which can be passed to the `ocr_predictor` are: + * `export_as_straight_boxes`: If you work with rotated and skewed documents but you still want to export straight bounding boxes and not polygons, set it to True. +* `straighten_pages`: If you want to straighten the pages before sending them to the detection model, set it to True. 
For instance, this snippet instantiates an end-to-end ocr_predictor working with rotated documents, which preserves the aspect ratio of the documents, and returns polygons: @@ -298,6 +301,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with from doctr.model import ocr_predictor model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True) + Additionally, you can change the batch size of the underlying detection and recognition predictors to optimize the performance depending on your hardware: * `det_bs`: batch size for the detection model (default: 2) @@ -465,6 +469,30 @@ This is useful to detect (possible less) text regions more accurately with a hig out = predictor([input_page]) +* Disable page orientation classification + +If you deal with documents which contains only small rotations (~ -45 to 45 degrees), you can disable the page orientation classification to speed up the inference. + +This will only have an effect with `assume_straight_pages=False` and/or `straighten_pages=True` and/or `detect_orientation=True`. + +.. code:: python3 + + from doctr.model import ocr_predictor + model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_page_orientation=True) + + +* Disable crop orientation classification + +If you deal with documents which contains only horizontal text, you can disable the crop orientation classification to speed up the inference. + +This will only have an effect with `assume_straight_pages=False` and/or `straighten_pages=True`. + +.. code:: python3 + + from doctr.model import ocr_predictor + model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_crop_orientation=True) + + * Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model. .. code:: python3 diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index d061250565..96f5c468ff 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-from typing import List, Union +from typing import List, Optional, Union import numpy as np import torch @@ -27,12 +27,12 @@ class OrientationPredictor(nn.Module): def __init__( self, - pre_processor: PreProcessor, - model: nn.Module, + pre_processor: Optional[PreProcessor], + model: Optional[nn.Module], ) -> None: super().__init__() - self.pre_processor = pre_processor - self.model = model.eval() + self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None + self.model = model.eval() if isinstance(model, nn.Module) else None @torch.inference_mode() def forward( @@ -43,12 +43,16 @@ def forward( if any(input.ndim != 3 for input in inputs): raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") + if self.model is None or self.pre_processor is None: + # predictor is disabled + return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)] + processed_batches = self.pre_processor(inputs) _params = next(self.model.parameters()) self.model, processed_batches = set_device_and_dtype( self.model, processed_batches, _params.device, _params.dtype ) - predicted_batches = [self.model(batch) for batch in processed_batches] + predicted_batches = [self.model(batch) for batch in processed_batches] # type: ignore[misc] # confidence probs = [ torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches @@ -57,7 +61,7 @@ def forward( predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches] class_idxs = [int(pred) for batch in predicted_batches for pred in batch] - classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] + classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] # type: ignore[union-attr] confs = [round(float(p), 2) for prob in probs for p in prob] return [class_idxs, classes, confs] diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index 95295584f1..e3756e6e83 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-from typing import List, Union +from typing import List, Optional, Union import numpy as np import tensorflow as tf @@ -29,11 +29,11 @@ class OrientationPredictor(NestedObject): def __init__( self, - pre_processor: PreProcessor, - model: keras.Model, + pre_processor: Optional[PreProcessor], + model: Optional[keras.Model], ) -> None: - self.pre_processor = pre_processor - self.model = model + self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None + self.model = model if isinstance(model, keras.Model) else None def __call__( self, @@ -43,6 +43,10 @@ def __call__( if any(input.ndim != 3 for input in inputs): raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") + if self.model is None or self.pre_processor is None: + # predictor is disabled + return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)] + processed_batches = self.pre_processor(inputs) predicted_batches = [self.model(batch, training=False) for batch in processed_batches] diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index fccd5b5979..16ffea0051 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -34,7 +34,13 @@ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] -def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor: +def _orientation_predictor( + arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any +) -> OrientationPredictor: + if disabled: + # Case where the orientation predictor is disabled + return OrientationPredictor(None, None) + if isinstance(arch, str): if arch not in ORIENTATION_ARCHS: raise ValueError(f"unknown architecture '{arch}'") diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py index 53d807898e..0b6cd28dc7 100644 --- a/doctr/models/kie_predictor/base.py +++ b/doctr/models/kie_predictor/base.py @@ -46,4 +46,8 @@ def __init__( assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs ) + # Remove the following arguments from kwargs after initialization of the parent class + kwargs.pop("disable_page_orientation", None) + kwargs.pop("disable_crop_orientation", None) + self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 4bcedc7064..c7ffa140c5 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -129,6 +129,7 @@ def forward( dict_loc_preds[class_name], channels_last=channels_last, assume_straight_pages=self.assume_straight_pages, + assume_horizontal=self._page_orientation_disabled, ) # Rectify crop orientation crop_orientations: Any = {} diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index d9d765bbe6..b73f651fc5 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -122,7 +122,11 @@ def __call__( crops = {} for class_name in dict_loc_preds.keys(): crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( - pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages + pages, + dict_loc_preds[class_name], + channels_last=True, + assume_straight_pages=self.assume_straight_pages, + 
assume_horizontal=self._page_orientation_disabled, ) # Rectify crop orientation diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 0469b32ea3..42f9142497 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -48,9 +48,15 @@ def __init__( ) -> None: self.assume_straight_pages = assume_straight_pages self.straighten_pages = straighten_pages - self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) + self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False) + self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False) + self.crop_orientation_predictor = ( + None + if assume_straight_pages + else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled) + ) self.page_orientation_predictor = ( - page_orientation_predictor(pretrained=True) + page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled) if detect_orientation or straighten_pages or not assume_straight_pages else None ) @@ -112,13 +118,18 @@ def _generate_crops( loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, + assume_horizontal: bool = False, ) -> List[List[np.ndarray]]: - extraction_fn = extract_crops if assume_straight_pages else extract_rcrops - - crops = [ - extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] - for page, _boxes in zip(pages, loc_preds) - ] + if assume_straight_pages: + crops = [ + extract_crops(page, _boxes[:, :4], channels_last=channels_last) + for page, _boxes in zip(pages, loc_preds) + ] + else: + crops = [ + extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal) + for page, _boxes in zip(pages, loc_preds) + ] return crops @staticmethod @@ -127,8 +138,9 @@ def _prepare_crops( loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, + assume_horizontal: bool = False, ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: - crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) + crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal) # Avoid sending zero-sized crops is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 7cbf383a06..326b89e5ff 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -123,6 +123,7 @@ def forward( loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages, + assume_horizontal=self._page_orientation_disabled, ) # Rectify crop orientation and get crop orientation predictions crop_orientations: Any = [] diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index f736614879..8f58062fd5 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -116,7 +116,11 @@ def __call__( # Crop images crops, loc_preds = self._prepare_crops( - pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages + pages, + loc_preds, + channels_last=True, + assume_straight_pages=self.assume_straight_pages, + assume_horizontal=self._page_orientation_disabled, ) # Rectify crop orientation and get crop orientation predictions crop_orientations: Any = [] diff --git 
a/doctr/utils/geometry.py b/doctr/utils/geometry.py index aceae8ca43..d16ac3df86 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -431,7 +431,7 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True def extract_rcrops( - img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True + img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False ) -> List[np.ndarray]: """Created cropped images from list of rotated bounding boxes @@ -441,6 +441,7 @@ def extract_rcrops( polys: bounding boxes of shape (N, 4, 2) dtype: target data type of bounding boxes channels_last: whether the channel dimensions is the last one instead of the last one + assume_horizontal: whether the boxes are assumed to be only horizontally oriented Returns: ------- @@ -458,22 +459,87 @@ def extract_rcrops( _boxes[:, :, 0] *= width _boxes[:, :, 1] *= height - src_pts = _boxes[:, :3].astype(np.float32) - # Preserve size - d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) - d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1) - # (N, 3, 2) - dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype) - dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1 - dst_pts[:, 2, 1] = d2 - 1 - # Use a warp transformation to extract the crop - crops = [ - cv2.warpAffine( - img if channels_last else img.transpose(1, 2, 0), - # Transformation matrix - cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), - (int(d1[idx]), int(d2[idx])), - ) - for idx in range(_boxes.shape[0]) - ] + src_img = img if channels_last else img.transpose(1, 2, 0) + + # Handle only horizontal oriented boxes + if assume_horizontal: + crops = [] + + for box in _boxes: + # Calculate the centroid of the quadrilateral + centroid = np.mean(box, axis=0) + + # Divide the points into left and right + left_points = box[box[:, 0] < centroid[0]] + right_points = box[box[:, 0] >= centroid[0]] + + # Sort the left points according to the y-axis + left_points = left_points[np.argsort(left_points[:, 1])] + top_left_pt = left_points[0] + bottom_left_pt = left_points[-1] + # Sort the right points according to the y-axis + right_points = right_points[np.argsort(right_points[:, 1])] + top_right_pt = right_points[0] + bottom_right_pt = right_points[-1] + box_points = np.array( + [top_left_pt, bottom_left_pt, top_right_pt, bottom_right_pt], + dtype=dtype, + ) + + # Get the width and height of the rectangle that will contain the warped quadrilateral + width_upper = np.linalg.norm(top_right_pt - top_left_pt) + width_lower = np.linalg.norm(bottom_right_pt - bottom_left_pt) + height_left = np.linalg.norm(bottom_left_pt - top_left_pt) + height_right = np.linalg.norm(bottom_right_pt - top_right_pt) + + # Get the maximum width and height + rect_width = max(int(width_upper), int(width_lower)) + rect_height = max(int(height_left), int(height_right)) + + dst_pts = np.array( + [ + [0, 0], # top-left + # bottom-left + [0, rect_height - 1], + # top-right + [rect_width - 1, 0], + # bottom-right + [rect_width - 1, rect_height - 1], + ], + dtype=dtype, + ) + + # Get the perspective transform matrix using the box points + affine_mat = cv2.getPerspectiveTransform(box_points, dst_pts) + + # Perform the perspective warp to get the rectified crop + crop = cv2.warpPerspective( + src_img, + affine_mat, + (rect_width, rect_height), + ) + + # Add the crop to the list of crops + crops.append(crop) + + # Handle any oriented boxes + else: + src_pts = _boxes[:, :3].astype(np.float32) + # 
Preserve size + d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) + d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1) + # (N, 3, 2) + dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype) + dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1 + dst_pts[:, 2, 1] = d2 - 1 + # Use a warp transformation to extract the crop + crops = [ + cv2.warpAffine( + src_img, + # Transformation matrix + cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), + (int(d1[idx]), int(d2[idx])), + ) + for idx in range(_boxes.shape[0]) + ] return crops # type: ignore[return-value] diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py index 984019e06c..afeed8a87c 100644 --- a/tests/common/test_utils_geometry.py +++ b/tests/common/test_utils_geometry.py @@ -234,7 +234,8 @@ def test_extract_crops(mock_pdf): assert geometry.extract_crops(doc_img, np.zeros((0, 4))) == [] -def test_extract_rcrops(mock_pdf): +@pytest.mark.parametrize("assume_horizontal", [True, False]) +def test_extract_rcrops(mock_pdf, assume_horizontal): doc_img = DocumentFile.from_pdf(mock_pdf)[0] num_crops = 2 rel_boxes = np.array( @@ -255,9 +256,9 @@ def test_extract_rcrops(mock_pdf): abs_boxes = abs_boxes.astype(np.int64) with pytest.raises(AssertionError): - geometry.extract_rcrops(doc_img, np.zeros((1, 8))) + geometry.extract_rcrops(doc_img, np.zeros((1, 8)), assume_horizontal=assume_horizontal) for boxes in (rel_boxes, abs_boxes): - croped_imgs = geometry.extract_rcrops(doc_img, boxes) + croped_imgs = geometry.extract_rcrops(doc_img, boxes, assume_horizontal=assume_horizontal) # Number of crops assert len(croped_imgs) == num_crops # Data type and shape @@ -265,4 +266,4 @@ def test_extract_rcrops(mock_pdf): assert all(crop.ndim == 3 for crop in croped_imgs) # No box - assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2))) == [] + assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2)), assume_horizontal=assume_horizontal) == [] diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index f35a1ac9de..4c0b571da9 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -112,7 +112,6 @@ def test_classification_zoo(arch_name): with torch.no_grad(): out = predictor(input_tensor) - out = predictor(input_tensor) class_idxs, classes, confs = out[0], out[1], out[2] assert isinstance(class_idxs, list) and len(class_idxs) == batch_size assert isinstance(classes, list) and len(classes) == batch_size @@ -134,6 +133,16 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test with disabled predictor + classifier = classification.crop_orientation_predictor( + "mobilenet_v3_small_crop_orientation", pretrained=False, disabled=True + ) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1.0, 1.0, 1.0, 1.0], + ] + # Test custom model loading classifier = classification.crop_orientation_predictor( classification.mobilenet_v3_small_crop_orientation(pretrained=True) @@ -150,12 +159,22 @@ def test_page_orientation_model(mock_payslip): text_box_270 = np.rot90(text_box_0, 1) text_box_180 = np.rot90(text_box_0, 2) text_box_90 = np.rot90(text_box_0, 3) - classifier = 
classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + classifier = classification.page_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] # 270 degrees is equivalent to -90 degrees assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test with disabled predictor + classifier = classification.page_orientation_predictor( + "mobilenet_v3_small_page_orientation", pretrained=False, disabled=True + ) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1.0, 1.0, 1.0, 1.0], + ] + # Test custom model loading classifier = classification.page_orientation_predictor( classification.mobilenet_v3_small_page_orientation(pretrained=True) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 9be66edd7b..3ea22ca9b6 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -25,14 +25,18 @@ def __call__(self, loc_preds): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation", [ - [True, False], - [False, False], - [True, True], + [True, False, False, False], + [False, False, True, True], + [True, True, False, False], + [False, True, True, True], + [True, False, True, False], ], ) -def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_ocrpredictor( + mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation +): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -64,6 +68,15 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa detect_language=True, resolve_blocks=True, resolve_lines=True, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, + ) + + assert ( + predictor._page_orientation_disabled if disable_page_orientation else not predictor._page_orientation_disabled + ) + assert ( + predictor._crop_orientation_disabled if disable_crop_orientation else not predictor._crop_orientation_disabled ) if assume_straight_pages: @@ -167,14 +180,18 @@ def test_trained_ocr_predictor(mock_payslip): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation", [ - [True, False], - [False, False], - [True, True], + [True, False, False, False], + [False, False, True, True], + [True, True, False, False], + [False, True, True, True], + [True, False, True, False], ], ) -def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_kiepredictor( + mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation +): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -206,6 +223,15 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa detect_language=True, resolve_blocks=True, resolve_lines=True, + disable_page_orientation=disable_page_orientation, 
+ disable_crop_orientation=disable_crop_orientation, + ) + + assert ( + predictor._page_orientation_disabled if disable_page_orientation else not predictor._page_orientation_disabled + ) + assert ( + predictor._crop_orientation_disabled if disable_crop_orientation else not predictor._crop_orientation_disabled ) if assume_straight_pages: diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 77eb8253ca..731e4dbd8b 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -113,6 +113,16 @@ def test_crop_orientation_model(mock_text_box): assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test with disabled predictor + classifier = classification.crop_orientation_predictor( + "mobilenet_v3_small_crop_orientation", pretrained=False, disabled=True + ) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1.0, 1.0, 1.0, 1.0], + ] + # Test custom model loading classifier = classification.crop_orientation_predictor( classification.mobilenet_v3_small_crop_orientation(pretrained=True) @@ -129,12 +139,22 @@ def test_page_orientation_model(mock_payslip): text_box_270 = np.rot90(text_box_0, 1) text_box_180 = np.rot90(text_box_0, 2) text_box_90 = np.rot90(text_box_0, 3) - classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + classifier = classification.page_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] # 270 degrees is equivalent to -90 degrees assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + # Test with disabled predictor + classifier = classification.page_orientation_predictor( + "mobilenet_v3_small_page_orientation", pretrained=False, disabled=True + ) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1.0, 1.0, 1.0, 1.0], + ] + # Test custom model loading classifier = classification.page_orientation_predictor( classification.mobilenet_v3_small_page_orientation(pretrained=True) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 4b7e606563..401aed68c2 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -25,14 +25,18 @@ def __call__(self, loc_preds): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation", [ - [True, False], - [False, False], - [True, True], + [True, False, False, False], + [False, False, True, True], + [True, True, False, False], + [False, True, True, True], + [True, False, True, False], ], ) -def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_ocrpredictor( + mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation +): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), 
batch_size=det_bsize), @@ -61,6 +65,15 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa detect_language=True, resolve_blocks=True, resolve_lines=True, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, + ) + + assert ( + predictor._page_orientation_disabled if disable_page_orientation else not predictor._page_orientation_disabled + ) + assert ( + predictor._crop_orientation_disabled if disable_crop_orientation else not predictor._crop_orientation_disabled ) if assume_straight_pages: @@ -166,14 +179,18 @@ def test_trained_ocr_predictor(mock_payslip): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation", [ - [True, False], - [False, False], - [True, True], + [True, False, False, False], + [False, False, True, True], + [True, True, False, False], + [False, True, True, True], + [True, False, True, False], ], ) -def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_kiepredictor( + mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation +): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -202,6 +219,15 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa detect_language=True, resolve_blocks=True, resolve_lines=True, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, + ) + + assert ( + predictor._page_orientation_disabled if disable_page_orientation else not predictor._page_orientation_disabled + ) + assert ( + predictor._crop_orientation_disabled if disable_crop_orientation else not predictor._crop_orientation_disabled ) if assume_straight_pages: From 7b7f7f38c817021e37594d0bd7de81ebe369530f Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 30 Sep 2024 12:03:39 +0200 Subject: [PATCH 07/18] [build] NumPy 2.0 support (#1709) --- .conda/meta.yaml | 2 +- doctr/transforms/functional/pytorch.py | 2 +- doctr/utils/metrics.py | 2 +- pyproject.toml | 4 +-- tests/conftest.py | 39 ++++++++++++++++---------- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index fcac492132..7feb3a1bf9 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -20,7 +20,7 @@ requirements: - setuptools run: - - numpy >=1.16.0, <2.0.0 + - numpy >=1.16.0, <3.0.0 - scipy >=1.4.0, <2.0.0 - pillow >=9.2.0 - h5py >=3.1.0, <4.0.0 diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index 65649ea2c8..740769d99c 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -89,7 +89,7 @@ def rotate_sample( rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2] rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1] - return rotated_img, np.clip(rotated_geoms, 0, 1) + return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1) def crop_detection( diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index faea10a3ab..6947298ede 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -149,7 +149,7 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: right = np.minimum(r1, r2.T) bot = np.minimum(b1, b2.T) - intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, 
np.Inf) + intersection = np.clip(right - left, 0, np.inf) * np.clip(bot - top, 0, np.inf) union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection iou_mat = intersection / union diff --git a/pyproject.toml b/pyproject.toml index c208d98652..c0b209f535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dynamic = ["version"] dependencies = [ # For proper typing, mypy needs numpy>=1.20.0 (cf. https://github.com/numpy/numpy/pull/16515) # Additional typing support is brought by numpy>=1.22.4, but core build sticks to >=1.16.0 - "numpy>=1.16.0,<2.0.0", + "numpy>=1.16.0,<3.0.0", "scipy>=1.4.0,<2.0.0", "h5py>=3.1.0,<4.0.0", "opencv-python>=4.5.0,<5.0.0", @@ -75,7 +75,6 @@ contrib = [ testing = [ "pytest>=5.3.2", "coverage[toml]>=4.5.4", - "hdf5storage>=0.1.18", "onnxruntime>=1.11.0", "requests>=2.20.0", "psutil>=5.9.5" @@ -112,7 +111,6 @@ dev = [ # Testing "pytest>=5.3.2", "coverage[toml]>=4.5.4", - "hdf5storage>=0.1.18", "onnxruntime>=1.11.0", "requests>=2.20.0", "psutil>=5.9.5", diff --git a/tests/conftest.py b/tests/conftest.py index 61bceeb392..9757b4eaad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,6 @@ from io import BytesIO import cv2 -import hdf5storage import numpy as np import pytest import requests @@ -257,24 +256,34 @@ def mock_imgur5k(tmpdir_factory, mock_image_stream): def mock_svhn_dataset(tmpdir_factory, mock_image_stream): root = tmpdir_factory.mktemp("datasets") svhn_root = root.mkdir("svhn") + train_root = svhn_root.mkdir("train") file = BytesIO(mock_image_stream) + + # NOTE: hdf5storage seems not to be maintained anymore, ref.: https://github.com/frejanordsiek/hdf5storage/pull/134 + # Instead we download the mocked data which was generated using the following code: # ascii image names - first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 1.png - second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 2.png - third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 3.png + # first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 1.png + # second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 2.png + # third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 3.png # labels: label is also ascii - label = { - "height": [35, 35, 35, 35], - "label": [1, 1, 3, 7], - "left": [116, 128, 137, 151], - "top": [27, 29, 29, 26], - "width": [15, 10, 17, 17], - } - - matcontent = {"digitStruct": {"name": [first, second, third], "bbox": [label, label, label]}} + # label = { + # "height": [35, 35, 35, 35], + # "label": [1, 1, 3, 7], + # "left": [116, 128, 137, 151], + # "top": [27, 29, 29, 26], + # "width": [15, 10, 17, 17], + # } + + # matcontent = {"digitStruct": {"name": [first, second, third], "bbox": [label, label, label]}} # Mock train data - train_root = svhn_root.mkdir("train") - hdf5storage.write(matcontent, filename=train_root.join("digitStruct.mat")) + # hdf5storage.write(matcontent, filename=train_root.join("digitStruct.mat")) + + # Downloading the mocked data + url = "https://github.com/mindee/doctr/releases/download/v0.9.0/digitStruct.mat" + response = requests.get(url) + with open(train_root.join("digitStruct.mat"), "wb") as f: + f.write(response.content) + for i in range(3): fn = train_root.join(f"{i + 1}.png") with open(fn, "wb") as f: From df762ed90010db4df9f4cb5692b52c2a2e5dc819 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 30 Sep 2024 14:52:55 +0200 Subject: [PATCH 08/18] [Fix] Remove image padding after rotation correction 
with `straighten_pages=True` (#1731) --- doctr/models/predictor/base.py | 6 +++--- doctr/utils/geometry.py | 21 +++++++++++++++++++++ tests/common/test_utils_geometry.py | 11 +++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 42f9142497..530590bc61 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -8,7 +8,7 @@ import numpy as np from doctr.models.builder import DocumentBuilder -from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image +from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds from ..classification import crop_orientation_predictor, page_orientation_predictor @@ -107,8 +107,8 @@ def _straighten_pages( ] ) return [ - # expand if height and width are not equal - rotate_image(page, angle, expand=page.shape[0] != page.shape[1]) + # expand if height and width are not equal, then remove the padding + remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1])) for page, angle in zip(pages, origin_pages_orientations) ] diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index d16ac3df86..21ad0dd7f4 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -20,6 +20,7 @@ "rotate_boxes", "compute_expanded_shape", "rotate_image", + "remove_image_padding", "estimate_page_angle", "convert_to_relative_coords", "rotate_abs_geoms", @@ -351,6 +352,26 @@ def rotate_image( return rot_img +def remove_image_padding(image: np.ndarray) -> np.ndarray: + """Remove black border padding from an image + + Args: + ---- + image: numpy tensor to remove padding from + + Returns: + ------- + Image with padding removed + """ + # Find the bounding box of the non-black region + rows = np.any(image, axis=1) + cols = np.any(image, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return image[rmin : rmax + 1, cmin : cmax + 1] + + def estimate_page_angle(polys: np.ndarray) -> float: """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the estimated angle ccw in degrees diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py index afeed8a87c..d2524a6ab7 100644 --- a/tests/common/test_utils_geometry.py +++ b/tests/common/test_utils_geometry.py @@ -142,6 +142,17 @@ def test_rotate_image(): assert rotated[0, :, 0].sum() <= 1 +def test_remove_image_padding(): + img = np.ones((32, 64, 3), dtype=np.float32) + padded = np.pad(img, ((10, 10), (20, 20), (0, 0))) + cropped = geometry.remove_image_padding(padded) + assert np.all(cropped == img) + + # No padding + cropped = geometry.remove_image_padding(img) + assert np.all(cropped == img) + + @pytest.mark.parametrize( "abs_geoms, img_size, rel_geoms", [ From dccc26b15b5b276051c5527631b6ab84ee330a38 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 1 Oct 2024 10:34:04 +0200 Subject: [PATCH 09/18] [TF] First changes on the road to Keras v3 (#1724) --- .../using_doctr/custom_models_training.rst | 20 ++++---- .../source/using_doctr/using_model_export.rst | 2 +- doctr/io/image/tensorflow.py | 2 +- .../classification/magc_resnet/tensorflow.py | 15 ++++-- .../classification/mobilenet/tensorflow.py | 22 ++++---- .../classification/predictor/tensorflow.py | 6 +-- .../classification/resnet/tensorflow.py | 30 +++++++---- 
.../classification/textnet/tensorflow.py | 14 +++-- doctr/models/classification/vgg/tensorflow.py | 12 +++-- doctr/models/classification/vit/tensorflow.py | 12 +++-- .../differentiable_binarization/tensorflow.py | 37 +++++++++----- doctr/models/detection/fast/tensorflow.py | 18 ++++--- doctr/models/detection/linknet/tensorflow.py | 20 +++++--- .../models/detection/predictor/tensorflow.py | 4 +- doctr/models/factory/hub.py | 13 +++-- doctr/models/modules/layers/tensorflow.py | 2 +- .../models/modules/transformer/tensorflow.py | 2 +- .../modules/vision_transformer/tensorflow.py | 2 +- doctr/models/recognition/crnn/tensorflow.py | 15 +++--- doctr/models/recognition/master/tensorflow.py | 11 ++-- doctr/models/recognition/parseq/tensorflow.py | 13 +++-- doctr/models/recognition/sar/tensorflow.py | 9 ++-- doctr/models/recognition/vitstr/tensorflow.py | 11 ++-- doctr/models/utils/tensorflow.py | 26 ++++------ pyproject.toml | 5 +- .../train_tensorflow_character.py | 12 +++-- .../train_tensorflow_orientation.py | 12 +++-- references/detection/evaluate_tensorflow.py | 2 +- references/detection/train_tensorflow.py | 22 ++++---- references/recognition/evaluate_tensorflow.py | 2 +- references/recognition/train_tensorflow.py | 18 ++++--- .../pytorch/test_models_classification_pt.py | 2 +- .../test_models_classification_tf.py | 12 +++-- tests/tensorflow/test_models_detection_tf.py | 11 ++-- tests/tensorflow/test_models_factory.py | 51 +++++++++---------- .../tensorflow/test_models_recognition_tf.py | 9 ++-- tests/tensorflow/test_models_utils_tf.py | 18 +++---- 37 files changed, 287 insertions(+), 207 deletions(-) diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index ecf88d8116..13e4640a36 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -22,19 +22,19 @@ This section shows how you can easily load a custom trained model in docTR. # Load custom detection model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True) # Load custom recognition model reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True) # Load custom detection and recognition model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False) .. 
tab:: PyTorch @@ -77,7 +77,7 @@ Load a custom recognition model trained on another vocabulary as the default one from doctr.datasets import VOCABS reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"]) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) @@ -106,7 +106,7 @@ Load a custom trained KIE detection model: from doctr.models import kie_predictor, db_resnet50 det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date']) - det_model.load_weights("/weights") + det_model.load_weights("") kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True) .. tab:: PyTorch @@ -136,9 +136,9 @@ Load a model with customized Preprocessor: from doctr.models import db_resnet50, crnn_vgg16_bn det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") det_predictor = DetectionPredictor( PreProcessor( @@ -233,9 +233,9 @@ Loading your custom trained orientation classification model from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) - custom_page_orientation_model.load_weights("/weights") + custom_page_orientation_model.load_weights("") custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) - custom_crop_orientation_model.load_weights("/weights") + custom_crop_orientation_model.load_weights("") predictor = ocr_predictor( pretrained=True, diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index c62c36169b..48f570f699 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -31,7 +31,7 @@ Advantages: .. 
code:: python3 import tensorflow as tf - from tensorflow.keras import mixed_precision + from keras import mixed_precision mixed_precision.set_global_policy('mixed_float16') predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index 28fb2fadd5..3b1f1ed0e2 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -7,8 +7,8 @@ import numpy as np import tensorflow as tf +from keras.utils import img_to_array from PIL import Image -from tensorflow.keras.utils import img_to_array from doctr.utils.common_types import AbstractPath diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index e791e661bf..12f7c6beea 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import activations, layers +from keras.models import Sequential from doctr.datasets import VOCABS @@ -26,7 +26,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0", }, } @@ -57,6 +57,7 @@ def __init__( self.headers = headers # h self.inplanes = inplanes # C self.attn_scale = attn_scale + self.ratio = ratio self.planes = int(inplanes * ratio) self.single_header_inplanes = int(inplanes / headers) # C / h @@ -97,7 +98,7 @@ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor: if self.attn_scale and self.headers > 1: context_mask = context_mask / math.sqrt(self.single_header_inplanes) # B*h, 1, H*W, 1 - context_mask = tf.keras.activations.softmax(context_mask, axis=2) + context_mask = activations.softmax(context_mask, axis=2) # Compute context # B*h, 1, C/h, 1 @@ -153,7 +154,11 @@ def _magc_resnet( ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index 2156cf1f50..6250abc666 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import layers +from keras.models import Sequential from ....datasets import VOCABS from ...utils import conv_sequence, load_pretrained_params @@ -32,42 +32,42 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0", }, "mobilenet_v3_large_r": { 
"mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0", }, "mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0", }, "mobilenet_v3_small_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0", }, "mobilenet_v3_small_crop_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (128, 128, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0", }, "mobilenet_v3_small_page_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (512, 512, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0", }, } @@ -297,7 +297,11 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index e3756e6e83..ba26e1db54 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +from keras import Model from doctr.models.preprocessor import PreProcessor from doctr.utils.repr import NestedObject @@ -30,10 +30,10 @@ class OrientationPredictor(NestedObject): def __init__( self, pre_processor: Optional[PreProcessor], - model: Optional[keras.Model], + model: Optional[Model], ) -> None: self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None - self.model = model if isinstance(model, keras.Model) else None + self.model = model if isinstance(model, Model) else None def __call__( self, diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 7648e5f8d0..3e78ae0ae2 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -7,9 +7,9 @@ from typing import 
Any, Callable, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.applications import ResNet50 -from tensorflow.keras.models import Sequential +from keras import layers +from keras.applications import ResNet50 +from keras.models import Sequential from doctr.datasets import VOCABS @@ -24,35 +24,35 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0", }, "resnet31": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0", }, "resnet34": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0", }, "resnet50": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0", }, "resnet34_wide": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0", }, } @@ -212,7 +212,11 @@ def _resnet( ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model @@ -357,7 +361,13 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs["resnet50"]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, + default_cfgs["resnet50"]["url"], + skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]), + ) return model diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index f30d5d823c..3d79b15f09 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS @@ -22,21 +22,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": 
"https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0", }, "textnet_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0", }, "textnet_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0", }, } @@ -113,7 +113,11 @@ def _textnet( model = TextNet(cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 259ed9f888..d9e7bb374b 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -6,8 +6,8 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import layers +from keras.models import Sequential from doctr.datasets import VOCABS @@ -22,7 +22,7 @@ "std": (1.0, 1.0, 1.0), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0", }, } @@ -83,7 +83,11 @@ def _vgg( model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 4b73b49ac9..28ff2e244e 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import EncoderBlock @@ -25,14 +25,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (3, 32, 32), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0", }, "vit_b": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 
3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0", }, } @@ -123,7 +123,11 @@ def _vit( model = VisionTransformer(cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index df9935b042..7fdbd43ce0 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -10,9 +10,8 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -from tensorflow.keras.applications import ResNet50 +from keras import Model, Sequential, layers, losses +from keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params @@ -29,13 +28,13 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0", }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0", }, } @@ -81,7 +80,7 @@ def build_upsampling( if dilation_factor > 1: _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest")) - module = keras.Sequential(_layers) + module = Sequential(_layers) return module @@ -104,7 +103,7 @@ def call( return layers.concatenate(results) -class DBNet(_DBNet, keras.Model, NestedObject): +class DBNet(_DBNet, Model, NestedObject): """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" `_. 
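Note on the recurring change above: `load_pretrained_params` now receives `skip_mismatch`, so a pretrained checkpoint can still be restored when the classification head is resized for a custom character set; only the layers whose shapes differ are skipped. A minimal fine-tuning sketch under that assumption (the `num_classes` keyword is forwarded by the factory functions as in the diff):

from doctr.datasets import VOCABS
from doctr.models import classification

# Backbone weights come from the pretrained checkpoint; the classification head
# is re-initialized because its shape no longer matches the reference config.
model = classification.resnet18(pretrained=True, num_classes=len(VOCABS["english"]))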
@@ -147,14 +146,14 @@ def __init__( _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape] output_shape = tuple(self.fpn(_inputs).shape) - self.probability_head = keras.Sequential([ + self.probability_head = Sequential([ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), layers.BatchNormalization(), layers.Activation("relu"), layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), ]) - self.threshold_head = keras.Sequential([ + self.threshold_head = Sequential([ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), layers.BatchNormalization(), @@ -206,7 +205,7 @@ def compute_loss( # Focal loss focal_scale = 10.0 - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) # Convert logits to prob, compute gamma factor p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) @@ -307,7 +306,12 @@ def _db_resnet( model = DBNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model @@ -326,6 +330,10 @@ def _db_mobilenet( # Patch the config _cfg = deepcopy(default_cfgs[arch]) _cfg["input_shape"] = input_shape or _cfg["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) # Feature extractor feat_extractor = IntermediateLayerGetter( @@ -341,7 +349,12 @@ def _db_mobilenet( model = DBNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 69998a2303..80fc31fea3 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -10,8 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import Sequential, layers +from keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params @@ -29,19 +28,19 @@ "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0", }, "fast_small": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": 
"https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0", }, "fast_base": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0", }, } @@ -100,7 +99,7 @@ def __init__( super().__init__(_layers) -class FAST(_FAST, keras.Model, NestedObject): +class FAST(_FAST, Model, NestedObject): """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" `_. @@ -336,7 +335,12 @@ def _fast( model = FAST(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) # Build the model for reparameterization to access the layers _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index ff11dbe477..683c49373a 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -10,8 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers, losses from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 @@ -27,19 +26,19 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0", }, "linknet_resnet34": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0", }, "linknet_resnet50": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0", }, } @@ -90,7 +89,7 @@ def extra_repr(self) -> str: return f"out_chans={self.out_chans}" -class LinkNet(_LinkNet, keras.Model): +class LinkNet(_LinkNet, Model): """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_. 
@@ -187,7 +186,7 @@ def compute_loss( seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) seg_mask = tf.cast(seg_mask, tf.float32) - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) proba_map = tf.sigmoid(out_map) # Focal loss @@ -277,7 +276,12 @@ def _linknet( model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index 14f38172df..a7ccd4a9ac 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +from keras import Model from doctr.models.detection._utils import _remove_padding from doctr.models.preprocessor import PreProcessor @@ -30,7 +30,7 @@ class DetectionPredictor(NestedObject): def __init__( self, pre_processor: PreProcessor, - model: keras.Model, + model: Model, ) -> None: self.pre_processor = pre_processor self.model = model diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 41cd91579a..b5844dd30b 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -20,7 +20,6 @@ get_token_permission, hf_hub_download, login, - snapshot_download, ) from doctr import models @@ -28,6 +27,8 @@ if is_torch_available(): import torch +elif is_tf_available(): + import tensorflow as tf __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] @@ -74,7 +75,9 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task weights_path = save_directory / "pytorch_model.bin" torch.save(model.state_dict(), weights_path) elif is_tf_available(): - weights_path = save_directory / "tf_model" / "weights" + weights_path = save_directory / "tf_model.weights.h5" + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.save_weights(str(weights_path)) config_path = save_directory / "config.json" @@ -225,7 +228,9 @@ def from_hub(repo_id: str, **kwargs: Any): state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu") model.load_state_dict(state_dict) else: # tf - repo_path = snapshot_download(repo_id, **kwargs) - model.load_weights(os.path.join(repo_path, "tf_model", "weights")) + weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs) + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) + model.load_weights(weights) return model diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py index 68849fbf6e..b1019be778 100644 --- a/doctr/models/modules/layers/tensorflow.py +++ b/doctr/models/modules/layers/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr 
import NestedObject diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py index 403f99117d..eef4f3dbea 100644 --- a/doctr/models/modules/transformer/tensorflow.py +++ b/doctr/models/modules/transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py index 8386172eb1..a73aa4c706 100644 --- a/doctr/models/modules/vision_transformer/tensorflow.py +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 5ec48c4f0e..d366bfc14b 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Model, Sequential +from keras import layers +from keras.models import Model, Sequential from doctr.datasets import VOCABS @@ -24,21 +24,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["legacy_french"], - "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0", }, "crnn_mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0", }, "crnn_mobilenet_v3_large": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0", }, } @@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model): def __init__( self, - feature_extractor: tf.keras.Model, + feature_extractor: Model, vocab: str, rnn_units: int = 128, exportable: bool = False, @@ -247,7 +247,8 @@ def _crnn( model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]) return model diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index a3ecadcc15..5b8192dee6 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from 
doctr.models.classification import magc_resnet31 @@ -25,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0", }, } @@ -51,7 +51,7 @@ class MASTER(_MASTER, Model): def __init__( self, - feature_extractor: tf.keras.Model, + feature_extractor: Model, vocab: str, d_model: int = 512, dff: int = 2048, @@ -292,7 +292,10 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 1365a6ac12..bca7806903 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward @@ -27,7 +27,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0", }, } @@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer): def __init__(self, vocab_size: int, d_model: int): super(CharEmbedding, self).__init__() - self.embedding = tf.keras.layers.Embedding(vocab_size, d_model) + self.embedding = layers.Embedding(vocab_size, d_model) self.d_model = d_model def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: @@ -238,7 +238,7 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple def decode( self, target: tf.Tensor, - memory: tf, + memory: tf.Tensor, target_mask: Optional[tf.Tensor] = None, target_query: Optional[tf.Tensor] = None, **kwargs: Any, @@ -478,7 +478,10 @@ def _parseq( model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index e5e557c232..0776414c7a 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers from doctr.datasets import VOCABS from doctr.utils.repr import NestedObject @@ -24,7 +24,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0", + "url": 
"https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0", }, } @@ -394,7 +394,10 @@ def _sar( model = SAR(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 9c5359dde2..985f49a470 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS @@ -23,14 +23,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0", }, "vitstr_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0", }, } @@ -218,7 +218,10 @@ def _vitstr( model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 4c6f02c2a3..51a2bc69a5 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -4,13 +4,11 @@ # See LICENSE or go to for full license details. 
import logging -import os from typing import Any, Callable, List, Optional, Tuple, Union -from zipfile import ZipFile import tensorflow as tf import tf2onnx -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.utils.data import download_from_url @@ -40,22 +38,20 @@ def load_pretrained_params( model: Model, url: Optional[str] = None, hash_prefix: Optional[str] = None, - overwrite: bool = False, - internal_name: str = "weights", + skip_mismatch: bool = False, **kwargs: Any, ) -> None: """Load a set of parameters onto a model >>> from doctr.models import load_pretrained_params - >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") + >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5") Args: ---- model: the keras model to be loaded url: URL of the zipped set of parameters hash_prefix: first characters of SHA256 expected hash - overwrite: should the zip extraction be enforced if the archive has already been extracted - internal_name: name of the ckpt files + skip_mismatch: skip loading layers with mismatched shapes **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url` """ if url is None: @@ -63,14 +59,12 @@ def load_pretrained_params( else: archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) - # Unzip the archive - params_path = archive_path.parent.joinpath(archive_path.stem) - if not params_path.is_dir() or overwrite: - with ZipFile(archive_path, "r") as f: - f.extractall(path=params_path) + # Build the model + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) # Load weights - model.load_weights(f"{params_path}{os.sep}{internal_name}") + model.load_weights(archive_path, skip_mismatch=skip_mismatch) def conv_sequence( @@ -83,7 +77,7 @@ def conv_sequence( ) -> List[layers.Layer]: """Builds a convolutional-based layer sequence - >>> from tensorflow.keras import Sequential + >>> from keras import Sequential >>> from doctr.models import conv_sequence >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) @@ -119,7 +113,7 @@ def conv_sequence( class IntermediateLayerGetter(Model): """Implements an intermediate layer getter - >>> from tensorflow.keras.applications import ResNet50 + >>> from keras.applications import ResNet50 >>> from doctr.models import IntermediateLayerGetter >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) diff --git a/pyproject.toml b/pyproject.toml index c0b209f535..aa0e02f98e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ tf = [ # cf. https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.11.0,<2.16.0", "tf2onnx>=1.16.0,<2.0.0", # cf. 
https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 + # TODO: This is a temporary fix until we can upgrade to a newer version of tensorflow + "numpy>=1.16.0,<2.0.0", ] torch = [ "torch>=1.12.0,<3.0.0", @@ -158,6 +160,7 @@ implicit_reexport = false module = [ "anyascii.*", "tensorflow.*", + "keras.*", "torchvision.*", "onnxruntime.*", "PIL.*", @@ -195,7 +198,7 @@ ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408"] [tool.ruff.lint.isort] known-first-party = ["doctr", "app", "utils"] -known-third-party = ["tensorflow", "torch", "torchvision", "wandb", "tqdm", "fastapi", "onnxruntime", "cv2"] +known-third-party = ["tensorflow", "keras", "torch", "torchvision", "wandb", "tqdm", "fastapi", "onnxruntime", "cv2"] [tool.ruff.lint.per-file-ignores] "doctr/models/**.py" = ["N806", "F841"] diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index 580cf6fb1b..b2d24f2dbf 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -30,7 +30,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -176,6 +176,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ @@ -227,14 +229,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (1e3), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam( + optimizer = optimizers.Adam( learning_rate=scheduler, beta_1=0.95, beta_2=0.99, @@ -291,7 +293,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index ad25713df7..e063174944 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -44,7 +44,7 @@ def rnd_rotate(img: tf.Tensor, target): def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -187,6 +187,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, *input_size, 3)), training=False) model.load_weights(args.resume) 
batch_transforms = T.Compose([ @@ -237,14 +239,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (1e3), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam( + optimizer = optimizers.Adam( learning_rate=scheduler, beta_1=0.95, beta_2=0.99, @@ -301,7 +303,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index 139932f2c4..abf012ed83 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -14,7 +14,7 @@ from pathlib import Path import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.experimental.list_physical_devices("GPU") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 1312a6ea13..b9c14494ad 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -14,7 +14,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -31,7 +31,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -58,7 +58,7 @@ def record_lr( # Forward, Backward & update with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: @@ -90,7 +90,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): images = batch_transforms(images) with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) @@ -107,7 +107,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_iter = iter(val_loader) for images, targets in tqdm(val_iter): images = batch_transforms(images) - out = model(images, targets, training=False, return_preds=True) + out = model(images, target=targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] for target, loc_pred in zip(targets, loc_preds): @@ -184,6 +184,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) if isinstance(args.pretrained_backbone, str): @@ -278,7 +280,7 @@ def main(args): # Scheduler if args.sched == "exponential": - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = 
optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (25e4), # final lr as a fraction of initial lr @@ -286,7 +288,7 @@ def main(args): name="ExponentialDecay", ) elif args.sched == "poly": - scheduler = tf.keras.optimizers.schedules.PolynomialDecay( + scheduler = optimizers.schedules.PolynomialDecay( args.lr, decay_steps=args.epochs * len(train_loader), end_learning_rate=1e-7, @@ -295,7 +297,7 @@ def main(args): name="PolynomialDecay", ) # Optimizer - optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) + optimizer = optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) # LR Finder @@ -351,11 +353,11 @@ def main(args): val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss if args.save_interval_epoch: print(f"Saving state at epoch: {epoch + 1}") - model.save_weights(f"./{exp_name}_{epoch + 1}/weights") + model.save_weights(f"./{exp_name}_{epoch + 1}.weights.h5") log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " if any(val is None for val in (recall, precision, mean_iou)): log_msg += "(Undefined metric value, caused by empty GTs or predictions)" diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index 62651245c4..4c9d125285 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -11,7 +11,7 @@ import time import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.experimental.list_physical_devices("GPU") diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index 7f55142859..c76355a2f2 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -15,7 +15,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -32,7 +32,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -59,7 +59,7 @@ def record_lr( # Forward, Backward & update with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: @@ -91,7 +91,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): images = batch_transforms(images) with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) @@ -108,7 +108,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_iter = iter(val_loader) for images, targets in tqdm(val_iter): images = batch_transforms(images) - out = 
model(images, targets, return_preds=True, training=False) + out = model(images, target=targets, return_preds=True, training=False) # Compute metric if len(out["preds"]): words, _ = zip(*out["preds"]) @@ -184,6 +184,8 @@ def main(args): ) # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, 4 * args.input_size, 3)), training=False) model.load_weights(args.resume) # Metrics @@ -275,14 +277,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (25e4), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) + optimizer = optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) # LR Finder @@ -343,7 +345,7 @@ def main(args): val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print( f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index 4c0b571da9..b3d25af173 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -54,7 +54,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): model = classification.__dict__[arch_name](pretrained=True).eval() _test_classification(model, input_shape, output_size) # Check that you can pretrained everything up until the last layer - classification.__dict__[arch_name](pretrained=True, num_classes=10) + assert classification.__dict__[arch_name](pretrained=True, num_classes=10) @pytest.mark.parametrize( diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 731e4dbd8b..11f4ea4114 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -2,6 +2,7 @@ import tempfile import cv2 +import keras import numpy as np import onnxruntime import psutil @@ -37,7 +38,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) # Forward out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) @@ -45,6 +46,11 @@ def test_classification_architectures(arch_name, input_shape, output_size): assert isinstance(out, tf.Tensor) assert out.dtype == tf.float32 assert out.numpy().shape == (batch_size, *output_size) + # Check that you can load pretrained up to the classification layer with differing number of classes to fine-tune + keras.backend.clear_session() + assert classification.__dict__[arch_name]( + pretrained=True, include_top=True, input_shape=input_shape, num_classes=10 + ) @pytest.mark.parametrize( @@ -57,7 +63,7 @@ def 
test_classification_architectures(arch_name, input_shape, output_size): def test_classification_models(arch_name, input_shape): batch_size = 8 reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) out = reco_model(input_tensor) @@ -226,7 +232,7 @@ def test_page_orientation_model(mock_payslip): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if "orientation" in arch_name: model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) else: diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py index 2e627b9e4d..ba5f50542b 100644 --- a/tests/tensorflow/test_models_detection_tf.py +++ b/tests/tensorflow/test_models_detection_tf.py @@ -2,6 +2,7 @@ import os import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -37,13 +38,13 @@ ) def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, input_shape=input_shape)) train_mode = False # Reparameterized model is not trainable else: model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(model, tf.keras.Model) + assert isinstance(model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = [ {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, @@ -152,7 +153,7 @@ def test_rotated_detectionpredictor(mock_pdf): ) def test_detection_zoo(arch_name): # Model - tf.keras.backend.clear_session() + keras.backend.clear_session() predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) # object check assert isinstance(predictor, DetectionPredictor) @@ -177,7 +178,7 @@ def test_fast_reparameterization(): base_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(base_model_params, 13535296) # base model params base_out = base_model(dummy_input, training=False)["logits"] - tf.keras.backend.clear_session() + keras.backend.clear_session() rep_model = reparameterize(base_model) rep_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(rep_model_params, 8520256) # reparameterized model params @@ -241,7 +242,7 @@ def test_dilate(): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True, input_shape=input_shape)) else: diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index 9b1ad2e166..0860d8612c 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -2,8 +2,8 @@ import os import tempfile +import keras import pytest -import tensorflow as tf from doctr import models from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub @@ -25,40 +25,39 @@ def test_push_to_hf_hub(): 
@pytest.mark.parametrize( "arch_name, task_name, dummy_model_id", [ - ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-tf-vgg16-bn-r"], - ["resnet18", "classification", "Felix92/doctr-dummy-tf-resnet18"], - ["resnet31", "classification", "Felix92/doctr-dummy-tf-resnet31"], - ["resnet34", "classification", "Felix92/doctr-dummy-tf-resnet34"], - ["resnet34_wide", "classification", "Felix92/doctr-dummy-tf-resnet34-wide"], - ["resnet50", "classification", "Felix92/doctr-dummy-tf-resnet50"], - ["magc_resnet31", "classification", "Felix92/doctr-dummy-tf-magc-resnet31"], - ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-tf-mobilenet-v3-large"], - ["vit_b", "classification", "Felix92/doctr-dummy-tf-vit-b"], - ["textnet_tiny", "classification", "Felix92/doctr-dummy-tf-textnet-tiny"], - ["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50"], - ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large"], - ["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18"], - ["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34"], - ["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50"], - ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn"], - ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-tf-crnn-mobilenet-v3-large"], - ["sar_resnet31", "recognition", "Felix92/doctr-dummy-tf-sar-resnet31"], - ["master", "recognition", "Felix92/doctr-dummy-tf-master"], - ["vitstr_small", "recognition", "Felix92/doctr-dummy-tf-vitstr-small"], - ["parseq", "recognition", "Felix92/doctr-dummy-tf-parseq"], + ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-tf-vgg16-bn-r-v2"], + ["resnet18", "classification", "Felix92/doctr-dummy-tf-resnet18-v2"], + ["resnet31", "classification", "Felix92/doctr-dummy-tf-resnet31-v2"], + ["resnet34", "classification", "Felix92/doctr-dummy-tf-resnet34-v2"], + ["resnet34_wide", "classification", "Felix92/doctr-dummy-tf-resnet34-wide-v2"], + ["resnet50", "classification", "Felix92/doctr-dummy-tf-resnet50-v2"], + ["magc_resnet31", "classification", "Felix92/doctr-dummy-tf-magc-resnet31-v2"], + ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-tf-mobilenet-v3-large-v2"], + ["vit_b", "classification", "Felix92/doctr-dummy-tf-vit-b-v2"], + ["textnet_tiny", "classification", "Felix92/doctr-dummy-tf-textnet-tiny-v2"], + ["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50-v2"], + ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large-v2"], + ["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18-v2"], + ["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50-v2"], + ["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34-v2"], + ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn-v2"], + ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-tf-crnn-mobilenet-v3-large-v2"], + ["sar_resnet31", "recognition", "Felix92/doctr-dummy-tf-sar-resnet31-v2"], + ["master", "recognition", "Felix92/doctr-dummy-tf-master-v2"], + ["vitstr_small", "recognition", "Felix92/doctr-dummy-tf-vitstr-small-v2"], + ["parseq", "recognition", "Felix92/doctr-dummy-tf-parseq-v2"], ], ) def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): with tempfile.TemporaryDirectory() as tmp_dir: - tf.keras.backend.clear_session() + keras.backend.clear_session() model = models.__dict__[task_name].__dict__[arch_name](pretrained=True) 
_save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) assert hasattr(model, "cfg") assert len(os.listdir(tmp_dir)) == 2 - assert os.path.exists(tmp_dir + "/tf_model") - assert len(os.listdir(tmp_dir + "/tf_model")) == 3 + assert os.path.exists(tmp_dir + "/tf_model.weights.h5") assert os.path.exists(tmp_dir + "/config.json") tmp_config = json.load(open(tmp_dir + "/config.json")) assert arch_name == tmp_config["arch"] @@ -66,6 +65,6 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): assert all(key in model.cfg.keys() for key in tmp_config.keys()) # test from hub - tf.keras.backend.clear_session() + keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index b58272d1de..7da1cb534a 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -2,6 +2,7 @@ import shutil import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -37,10 +38,10 @@ ["parseq", (32, 128, 3)], ], ) -def test_recognition_models(arch_name, input_shape, train_mode): +def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): batch_size = 4 - reco_model = recognition.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + reco_model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = ["i", "am", "a", "jedi"] @@ -194,7 +195,7 @@ def test_recognition_zoo_error(): def test_models_onnx_export(arch_name, input_shape): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = recognition.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) # SAR, MASTER, ViTSTR export currently only available with constant batch size if arch_name in ["sar_resnet31", "master", "vitstr_small", "parseq"]: diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index b83e60c0ee..b57b41b14b 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -2,9 +2,9 @@ import pytest import tensorflow as tf -from tensorflow.keras import Sequential, layers -from tensorflow.keras.applications import ResNet50 +from keras.applications import ResNet50 +from doctr.models.classification import mobilenet_v3_small from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, @@ -27,20 +27,18 @@ def test_bf16_to_float32(): def test_load_pretrained_params(tmpdir_factory): - model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)]) + model = mobilenet_v3_small(pretrained=False) # Retrieve this URL - url = "https://doctr-static.mindee.com/models?id=v0.1-models/tmp_checkpoint-4a98e492.zip&src=0" + url = "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0" # Temp cache dir cache_dir = tmpdir_factory.mktemp("cache") # Pass an incorrect hash with pytest.raises(ValueError): - load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir), internal_name="") + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir)) # Let tit resolve the hash from the 
file name - load_pretrained_params(model, url, cache_dir=str(cache_dir), internal_name="") - # Check that the file was downloaded & the archive extracted - assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492")) - # Check that archive was deleted - assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492.zip")) + load_pretrained_params(model, url, cache_dir=str(cache_dir)) + # Check that the file was downloaded + assert os.path.exists(cache_dir.join("models").join("mobilenet_v3_small-3fcebad7.weights.h5")) def test_conv_sequence(): From 7f6757c968432f25e0358124f0926ef6a33bcf8d Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 1 Oct 2024 10:42:02 +0200 Subject: [PATCH 10/18] [datasets] Allow detection task for built-in datasets (#1717) --- Makefile | 2 +- docs/source/using_doctr/using_datasets.rst | 19 +- doctr/datasets/cord.py | 11 +- doctr/datasets/funsd.py | 12 +- doctr/datasets/ic03.py | 12 +- doctr/datasets/ic13.py | 11 +- doctr/datasets/iiit5k.py | 42 ++-- doctr/datasets/imgur5k.py | 11 +- doctr/datasets/sroie.py | 12 +- doctr/datasets/svhn.py | 12 +- doctr/datasets/svt.py | 12 +- doctr/datasets/synthtext.py | 12 +- doctr/datasets/utils.py | 9 +- doctr/datasets/wildreceipt.py | 13 +- references/detection/evaluate_pytorch.py | 14 +- references/detection/evaluate_tensorflow.py | 14 +- tests/pytorch/test_datasets_pt.py | 251 ++++++++++++++++---- tests/tensorflow/test_datasets_tf.py | 242 +++++++++++++++---- 18 files changed, 586 insertions(+), 125 deletions(-) diff --git a/Makefile b/Makefile index 428bc4fc4a..04662b9613 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,8 @@ quality: # this target runs checks on all files and potentially modifies some of them style: - ruff check --fix . ruff format . + ruff check --fix . # Run tests for the library test: diff --git a/docs/source/using_doctr/using_datasets.rst b/docs/source/using_doctr/using_datasets.rst index 52c5f7e24d..5fd5dc2776 100644 --- a/docs/source/using_doctr/using_datasets.rst +++ b/docs/source/using_doctr/using_datasets.rst @@ -48,9 +48,9 @@ This datasets contains the information to train or validate a text detection mod from doctr.datasets import CORD # Load straight boxes - train_set = CORD(train=True, download=True) + train_set = CORD(train=True, download=True, detection_task=True) # Load rotated boxes - train_set = CORD(train=True, download=True, use_polygons=True) + train_set = CORD(train=True, download=True, use_polygons=True, detection_task=True) img, target = train_set[0] @@ -99,6 +99,21 @@ This datasets contains the information to train or validate a text recognition m img, target = train_set[0] +OCR +^^^ + +The same dataset table as for detection, but with information about the bounding boxes and labels. + +.. 
code:: python3 + + from doctr.datasets import CORD + # Load straight boxes + train_set = CORD(train=True, download=True) + # Load rotated boxes + train_set = CORD(train=True, download=True, use_polygons=True) + img, target = train_set[0] + + Object Detection ^^^^^^^^^^^^^^^^ diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py index b88fbb28e8..9e2188727d 100644 --- a/doctr/datasets/cord.py +++ b/doctr/datasets/cord.py @@ -33,6 +33,7 @@ class CORD(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -53,6 +54,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -64,10 +66,15 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) # List images tmp_root = os.path.join(self.root, "image") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] self.train = train np_dtype = np.float32 for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))): @@ -109,6 +116,8 @@ def __init__( ) for crop, label in zip(crops, list(text_targets)): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) else: self.data.append(( img_path, diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py index 0580b473a7..3bd8b088f9 100644 --- a/doctr/datasets/funsd.py +++ b/doctr/datasets/funsd.py @@ -33,6 +33,7 @@ class FUNSD(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -45,6 +46,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -55,6 +57,12 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." 
+ ) + self.train = train np_dtype = np.float32 @@ -63,7 +71,7 @@ def __init__( # # List images tmp_root = os.path.join(self.root, subfolder, "images") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))): # File existence check if not os.path.exists(os.path.join(tmp_root, img_path)): @@ -100,6 +108,8 @@ def __init__( # filter labels with unknown characters if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype))) else: self.data.append(( img_path, diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py index 6f080e4d45..b3af8d958c 100644 --- a/doctr/datasets/ic03.py +++ b/doctr/datasets/ic03.py @@ -32,6 +32,7 @@ class IC03(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -51,6 +52,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, file_name = self.TRAIN if train else self.TEST @@ -62,8 +64,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load xml data @@ -117,6 +125,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((name.text, boxes)) else: self.data.append((name.text, dict(boxes=boxes, labels=labels))) diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py index 81ba62f001..0082d92316 100644 --- a/doctr/datasets/ic13.py +++ b/doctr/datasets/ic13.py @@ -38,6 +38,7 @@ class IC13(AbstractDataset): label_folder: folder with all annotation files for the images use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -47,11 +48,17 @@ def __init__( label_folder: str, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. 
" + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) # File existence check if not os.path.exists(label_folder) or not os.path.exists(img_folder): @@ -59,7 +66,7 @@ def __init__( f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}" ) - self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 img_names = os.listdir(img_folder) @@ -95,5 +102,7 @@ def __init__( crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets) for crop, label in zip(crops, labels): self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, box_targets)) else: self.data.append((img_path, dict(boxes=box_targets, labels=labels))) diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py index 2b33ebb50b..89619dd8aa 100644 --- a/doctr/datasets/iiit5k.py +++ b/doctr/datasets/iiit5k.py @@ -34,6 +34,7 @@ class IIIT5K(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -45,6 +46,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -55,6 +57,12 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." 
+ ) + self.train = train # Load mat data @@ -62,7 +70,7 @@ def __init__( mat_file = "trainCharBound" if self.train else "testCharBound" mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0] - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)): @@ -73,24 +81,26 @@ def __init__( if not os.path.exists(os.path.join(tmp_root, _raw_path)): raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}") + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ + [ + [box[0], box[1]], + [box[0] + box[2], box[1]], + [box[0] + box[2], box[1] + box[3]], + [box[0], box[1] + box[3]], + ] + for box in box_targets + ] + else: + # xmin, ymin, xmax, ymax + box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] + if recognition_task: self.data.append((_raw_path, _raw_label)) + elif detection_task: + self.data.append((_raw_path, np.asarray(box_targets, dtype=np_dtype))) else: - if use_polygons: - # (x, y) coordinates of top left, top right, bottom right, bottom left corners - box_targets = [ - [ - [box[0], box[1]], - [box[0] + box[2], box[1]], - [box[0] + box[2], box[1] + box[3]], - [box[0], box[1] + box[3]], - ] - for box in box_targets - ] - else: - # xmin, ymin, xmax, ymax - box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] - # label are casted to list where each char corresponds to the character's bounding box self.data.append(( _raw_path, diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py index 3e7cf0e07b..4dcfec02b8 100644 --- a/doctr/datasets/imgur5k.py +++ b/doctr/datasets/imgur5k.py @@ -46,6 +46,7 @@ class IMGUR5K(AbstractDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -56,17 +57,23 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." 
+ ) # File existence check if not os.path.exists(label_path) or not os.path.exists(img_folder): raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") - self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] self.train = train np_dtype = np.float32 @@ -132,6 +139,8 @@ def __init__( tmp_img = Image.fromarray(crop) tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) reco_images_counter += 1 + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype))) else: self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels))) diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py index e72fde68a1..d6e7dac83b 100644 --- a/doctr/datasets/sroie.py +++ b/doctr/datasets/sroie.py @@ -33,6 +33,7 @@ class SROIE(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -52,6 +53,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -63,10 +65,16 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train tmp_root = os.path.join(self.root, "images") - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))): @@ -94,6 +102,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, coords)) else: self.data.append((img_path, dict(boxes=coords, labels=labels))) diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py index 57085c5213..595113a42d 100644 --- a/doctr/datasets/svhn.py +++ b/doctr/datasets/svhn.py @@ -32,6 +32,7 @@ class SVHN(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. 
""" @@ -52,6 +53,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: url, sha256, name = self.TRAIN if train else self.TEST @@ -63,8 +65,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 tmp_root = os.path.join(self.root, "train" if train else "test") @@ -122,6 +130,8 @@ def __init__( for crop, label in zip(crops, label_targets): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_name, box_targets)) else: self.data.append((img_name, dict(boxes=box_targets, labels=label_targets))) diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py index 3eb7b6d599..b9e88b4cc1 100644 --- a/doctr/datasets/svt.py +++ b/doctr/datasets/svt.py @@ -32,6 +32,7 @@ class SVT(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. """ @@ -43,6 +44,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -53,8 +55,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load xml data @@ -108,6 +116,8 @@ def __init__( for crop, label in zip(crops, labels): if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: self.data.append((crop, label)) + elif detection_task: + self.data.append((name.text, boxes)) else: self.data.append((name.text, dict(boxes=boxes, labels=labels))) diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py index a60e22e832..8be11e2303 100644 --- a/doctr/datasets/synthtext.py +++ b/doctr/datasets/synthtext.py @@ -35,6 +35,7 @@ class SynthText(VisionDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `VisionDataset`. 
""" @@ -46,6 +47,7 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -56,8 +58,14 @@ def __init__( pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs, ) + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." + ) + self.train = train - self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] np_dtype = np.float32 # Load mat data @@ -111,6 +119,8 @@ def __init__( tmp_img = Image.fromarray(crop) tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) reco_images_counter += 1 + elif detection_task: + self.data.append((img_path[0], np.asarray(word_boxes, dtype=np_dtype))) else: self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels))) diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py index 860e19a229..75182a227a 100644 --- a/doctr/datasets/utils.py +++ b/doctr/datasets/utils.py @@ -169,8 +169,13 @@ def encode_sequences( return encoded_data -def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]: - target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) +def convert_target_to_relative( + img: ImageTensor, target: Union[np.ndarray, Dict[str, Any]] +) -> Tuple[ImageTensor, Union[Dict[str, Any], np.ndarray]]: + if isinstance(target, np.ndarray): + target = convert_to_relative_coords(target, get_img_shape(img)) + else: + target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) return img, target diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py index 19108d7761..685266931a 100644 --- a/doctr/datasets/wildreceipt.py +++ b/doctr/datasets/wildreceipt.py @@ -40,6 +40,7 @@ class WILDRECEIPT(AbstractDataset): train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task + detection_task: whether the dataset should be used for detection task **kwargs: keyword arguments from `AbstractDataset`. """ @@ -50,11 +51,19 @@ def __init__( train: bool = True, use_polygons: bool = False, recognition_task: bool = False, + detection_task: bool = False, **kwargs: Any, ) -> None: super().__init__( img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs ) + # Task check + if recognition_task and detection_task: + raise ValueError( + "`recognition_task` and `detection_task` cannot be set to True simultaneously. " + + "To get the whole dataset with boxes and labels leave both parameters to False." 
+ ) + # File existence check if not os.path.exists(label_path) or not os.path.exists(img_folder): raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") @@ -62,7 +71,7 @@ def __init__( tmp_root = img_folder self.train = train np_dtype = np.float32 - self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any], np.ndarray]]] = [] with open(label_path, "r") as file: data = file.read() @@ -100,6 +109,8 @@ def __init__( for crop, label in zip(crops, list(text_targets)): if label and " " not in label: self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) else: self.data.append(( img_path, diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py index 15f60df664..10b20e40cc 100644 --- a/references/detection/evaluate_pytorch.py +++ b/references/detection/evaluate_pytorch.py @@ -37,7 +37,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): if torch.cuda.is_available(): images = images.cuda() images = batch_transforms(images) - targets = [{CLASS_NAME: t["boxes"]} for t in targets] + targets = [{CLASS_NAME: t} for t in targets] if amp: with torch.cuda.amp.autocast(): out = model(images, targets, return_preds=True) @@ -82,7 +82,10 @@ def main(args): train=True, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape), + detection_task=True, + sample_transforms=T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) # Monkeypatch subfolder = ds.root.split("/")[-2:] @@ -92,7 +95,10 @@ def main(args): train=False, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape), + detection_task=True, + sample_transforms=T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) subfolder = _ds.root.split("/")[-2:] ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) @@ -155,6 +161,8 @@ def parse_args(): parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") parser.add_argument("--device", default=None, type=int, help="device") parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index abf012ed83..4eef9a40b7 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -35,7 +35,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_loss, batch_cnt = 0, 0 for images, targets in tqdm(val_loader): images = batch_transforms(images) - targets = [{CLASS_NAME: t["boxes"]} for t in targets] + targets = [{CLASS_NAME: t} for t in targets] out = model(images, targets, training=False, return_preds=True) # 
Compute metric loc_preds = out["preds"] @@ -81,7 +81,10 @@ def main(args): train=True, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape[:2]), + detection_task=True, + sample_transforms=T.Resize( + input_shape[:2], preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) # Monkeypatch subfolder = ds.root.split("/")[-2:] @@ -91,7 +94,10 @@ def main(args): train=False, download=True, use_polygons=args.rotation, - sample_transforms=T.Resize(input_shape[:2]), + detection_task=True, + sample_transforms=T.Resize( + input_shape[:2], preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), ) subfolder = _ds.root.split("/")[-2:] ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) @@ -129,6 +135,8 @@ def parse_args(): parser.add_argument("--dataset", type=str, default="FUNSD", help="Dataset to evaluate on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 749a86bf06..30f9e6f288 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -72,6 +72,36 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2): assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) +def _validate_dataset_detection_part(ds, input_size, batch_size=2, is_polygons=False): + # Fetch one sample + img, target = ds[0] + + assert isinstance(img, torch.Tensor) + assert img.shape == (3, *input_size) + assert img.dtype == torch.float32 + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + if is_polygons: + assert target.ndim == 3 and target.shape[1:] == (4, 2) + else: + assert target.ndim == 2 and target.shape[1:] == (4,) + assert np.all(np.logical_and(target <= 1, target >= 0)) + + # Check batching + loader = DataLoader( + ds, + batch_size=batch_size, + drop_last=True, + sampler=RandomSampler(ds), + num_workers=0, + pin_memory=True, + collate_fn=ds.collate_fn, + ) + + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + def test_visiondataset(): url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" with pytest.raises(ValueError): @@ -282,13 +312,14 @@ def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts) @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples - [[32, 128], 15, True], # recognition + [[512, 512], 3, False, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True, 
False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): +def test_sroie(input_size, num_samples, rotate, recognition, detection, mock_sroie_dataset): # monkeypatch the path to temporary dataset datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") @@ -298,6 +329,7 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], ) @@ -306,67 +338,94 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) assert repr(ds) == f"SROIE(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SROIE( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 5, False], # Actual set has 229 train and 233 test samples - [[32, 128], 25, True], # recognition + [[512, 512], 5, False, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True, False], # recognition + [[512, 512], 5, False, True], # detection ], ) -def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): +def test_ic13_dataset(input_size, num_samples, rotate, recognition, detection, mock_ic13): ds = datasets.IC13( *mock_ic13, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC13(*mock_ic13, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples - [[32, 128], 5, True], # recognition + [[512, 512], 3, False, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, detection, mock_imgur5k): ds = datasets.IMGUR5K( *mock_imgur5k, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split assert repr(ds) == f"IMGUR5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, 
input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IMGUR5K(*mock_imgur5k, train=True, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples - [[32, 128], 12, True], # recognition + [[32, 128], 3, False, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True, False], # recognition + [[32, 128], 3, False, True], # detection ], ) -def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): +def test_svhn(input_size, num_samples, rotate, recognition, detection, mock_svhn_dataset): # monkeypatch the path to temporary dataset datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") @@ -376,6 +435,7 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], ) @@ -384,19 +444,32 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): assert repr(ds) == f"SVHN(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVHN( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): +def test_funsd(input_size, num_samples, rotate, recognition, detection, mock_funsd_dataset): # monkeypatch the path to temporary dataset datasets.FUNSD.URL = mock_funsd_dataset datasets.FUNSD.SHA256 = None @@ -408,6 +481,7 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], ) @@ -416,19 +490,32 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) assert repr(ds) == f"FUNSD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.FUNSD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + 
@pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): +def test_cord(input_size, num_samples, rotate, recognition, detection, mock_cord_dataset): # monkeypatch the path to temporary dataset datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") @@ -438,6 +525,7 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], ) @@ -446,19 +534,32 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): assert repr(ds) == f"CORD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.CORD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples - [[32, 128], 10, True], # recognition + [[512, 512], 2, False, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): +def test_synthtext(input_size, num_samples, rotate, recognition, detection, mock_synthtext_dataset): # monkeypatch the path to temporary dataset datasets.SynthText.URL = mock_synthtext_dataset datasets.SynthText.SHA256 = None @@ -469,6 +570,7 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], ) @@ -477,19 +579,32 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ assert repr(ds) == f"SynthText(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SynthText( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, 
recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples - [[32, 128], 1, True], # recognition + [[32, 128], 1, False, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True, False], # recognition + [[32, 128], 1, False, True], # detection ], ) -def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): +def test_iiit5k(input_size, num_samples, rotate, recognition, detection, mock_iiit5k_dataset): # monkeypatch the path to temporary dataset datasets.IIIT5K.URL = mock_iiit5k_dataset datasets.IIIT5K.SHA256 = None @@ -500,6 +615,7 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], ) @@ -508,19 +624,32 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase assert repr(ds) == f"IIIT5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size, batch_size=1) + elif detection: + _validate_dataset_detection_part(ds, input_size, batch_size=1, is_polygons=rotate) else: _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IIIT5K( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): +def test_svt(input_size, num_samples, rotate, recognition, detection, mock_svt_dataset): # monkeypatch the path to temporary dataset datasets.SVT.URL = mock_svt_dataset datasets.SVT.SHA256 = None @@ -531,6 +660,7 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], ) @@ -539,19 +669,32 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): assert repr(ds) == f"SVT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVT( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 246 training samples 
and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): +def test_ic03(input_size, num_samples, rotate, recognition, detection, mock_ic03_dataset): # monkeypatch the path to temporary dataset datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") @@ -561,6 +704,7 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], ) @@ -569,33 +713,52 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): assert repr(ds) == f"IC03(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC03( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], - [[32, 128], 5, True], + [[512, 512], 2, False, False], # Actual set has 1268 training samples and 472 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detection, mock_wildreceipt_dataset): ds = datasets.WILDRECEIPT( *mock_wildreceipt_dataset, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples assert repr(ds) == f"WILDRECEIPT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True) + # NOTE: following datasets are only for recognition task diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index 5d6c61b116..1129b4264e 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -54,6 +54,27 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2): assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) +def _validate_dataset_detection_part(ds, input_size, is_polygons=False, batch_size=2): + # Fetch one sample + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape == (*input_size, 3) + assert img.dtype == tf.float32 + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + if is_polygons: + assert target.ndim == 3 and target.shape[1:] == (4, 2) + else: + assert target.ndim == 2 and 
target.shape[1:] == (4,) + assert np.all(np.logical_and(target <= 1, target >= 0)) + + # Check batching + loader = DataLoader(ds, batch_size=batch_size) + + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + def test_visiondataset(): url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" with pytest.raises(ValueError): @@ -264,13 +285,14 @@ def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts) @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 626 training samples and 360 test samples - [[32, 128], 15, True], # recognition + [[512, 512], 3, False, False], # Actual set has 626 training samples and 360 test samples + [[32, 128], 15, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset): +def test_sroie(input_size, num_samples, rotate, recognition, detection, mock_sroie_dataset): # monkeypatch the path to temporary dataset datasets.SROIE.TRAIN = (mock_sroie_dataset, None, "sroie2019_train_task1.zip") @@ -280,6 +302,7 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], ) @@ -288,67 +311,94 @@ def test_sroie(input_size, num_samples, rotate, recognition, mock_sroie_dataset) assert repr(ds) == f"SROIE(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SROIE( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), + cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 5, False], # Actual set has 229 train and 233 test samples - [[32, 128], 25, True], # recognition + [[512, 512], 5, False, False], # Actual set has 229 train and 233 test samples + [[32, 128], 25, True, False], # recognition + [[512, 512], 5, False, True], # detection ], ) -def test_ic13_dataset(input_size, num_samples, rotate, recognition, mock_ic13): +def test_ic13_dataset(input_size, num_samples, rotate, recognition, detection, mock_ic13): ds = datasets.IC13( *mock_ic13, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC13(*mock_ic13, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, 
num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 7149 train and 796 test samples - [[32, 128], 5, True], # recognition + [[512, 512], 3, False, False], # Actual set has 7149 train and 796 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, mock_imgur5k): +def test_imgur5k_dataset(input_size, num_samples, rotate, recognition, detection, mock_imgur5k): ds = datasets.IMGUR5K( *mock_imgur5k, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split assert repr(ds) == f"IMGUR5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IMGUR5K(*mock_imgur5k, train=True, recognition_task=True, detection_task=True) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 3, False], # Actual set has 33402 training samples and 13068 test samples - [[32, 128], 12, True], # recognition + [[32, 128], 3, False, False], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 12, True, False], # recognition + [[32, 128], 3, False, True], # detection ], ) -def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): +def test_svhn(input_size, num_samples, rotate, recognition, detection, mock_svhn_dataset): # monkeypatch the path to temporary dataset datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") @@ -358,6 +408,7 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], ) @@ -366,19 +417,32 @@ def test_svhn(input_size, num_samples, rotate, recognition, mock_svhn_dataset): assert repr(ds) == f"SVHN(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVHN( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), + cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 149 training samples and 50 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 149 training samples and 50 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset): +def test_funsd(input_size, num_samples, rotate, recognition, detection, mock_funsd_dataset): # monkeypatch the 
path to temporary dataset datasets.FUNSD.URL = mock_funsd_dataset datasets.FUNSD.SHA256 = None @@ -390,6 +454,7 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], ) @@ -398,19 +463,32 @@ def test_funsd(input_size, num_samples, rotate, recognition, mock_funsd_dataset) assert repr(ds) == f"FUNSD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.FUNSD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), + cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 800 training samples and 100 test samples - [[32, 128], 9, True], # recognition + [[512, 512], 3, False, False], # Actual set has 800 training samples and 100 test samples + [[32, 128], 9, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): +def test_cord(input_size, num_samples, rotate, recognition, detection, mock_cord_dataset): # monkeypatch the path to temporary dataset datasets.CORD.TRAIN = (mock_cord_dataset, None, "cord_train.zip") @@ -420,6 +498,7 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], ) @@ -428,19 +507,32 @@ def test_cord(input_size, num_samples, rotate, recognition, mock_cord_dataset): assert repr(ds) == f"CORD(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.CORD( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), + cache_subdir=mock_cord_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], # Actual set has 772875 training samples and 85875 test samples - [[32, 128], 10, True], # recognition + [[512, 512], 2, False, False], # Actual set has 772875 training samples and 85875 test samples + [[32, 128], 10, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_dataset): +def test_synthtext(input_size, num_samples, rotate, recognition, detection, mock_synthtext_dataset): # monkeypatch the path to temporary dataset datasets.SynthText.URL = mock_synthtext_dataset datasets.SynthText.SHA256 = None @@ -451,6 +543,7 @@ def 
test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], ) @@ -459,19 +552,32 @@ def test_synthtext(input_size, num_samples, rotate, recognition, mock_synthtext_ assert repr(ds) == f"SynthText(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SynthText( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), + cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[32, 128], 1, False], # Actual set has 2000 training samples and 3000 test samples - [[32, 128], 1, True], # recognition + [[32, 128], 1, False, False], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, True, False], # recognition + [[32, 128], 1, False, True], # detection ], ) -def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_dataset): +def test_iiit5k(input_size, num_samples, rotate, recognition, detection, mock_iiit5k_dataset): # monkeypatch the path to temporary dataset datasets.IIIT5K.URL = mock_iiit5k_dataset datasets.IIIT5K.SHA256 = None @@ -482,6 +588,7 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], ) @@ -490,19 +597,32 @@ def test_iiit5k(input_size, num_samples, rotate, recognition, mock_iiit5k_datase assert repr(ds) == f"IIIT5K(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size, batch_size=1) + elif detection: + _validate_dataset_detection_part(ds, input_size, batch_size=1, is_polygons=rotate) else: _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IIIT5K( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), + cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 100 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 100 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): +def test_svt(input_size, num_samples, rotate, recognition, detection, mock_svt_dataset): # monkeypatch the path to temporary dataset datasets.SVT.URL = mock_svt_dataset datasets.SVT.SHA256 = None @@ -513,6 +633,7 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): 
img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], ) @@ -521,19 +642,32 @@ def test_svt(input_size, num_samples, rotate, recognition, mock_svt_dataset): assert repr(ds) == f"SVT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.SVT( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), + cache_subdir=mock_svt_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 3, False], # Actual set has 246 training samples and 249 test samples - [[32, 128], 3, True], # recognition + [[512, 512], 3, False, False], # Actual set has 246 training samples and 249 test samples + [[32, 128], 3, True, False], # recognition + [[512, 512], 3, False, True], # detection ], ) -def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): +def test_ic03(input_size, num_samples, rotate, recognition, detection, mock_ic03_dataset): # monkeypatch the path to temporary dataset datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") @@ -543,6 +677,7 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], ) @@ -551,33 +686,52 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): assert repr(ds) == f"IC03(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: _validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.IC03( + train=True, + download=True, + recognition_task=True, + detection_task=True, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), + cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize( - "input_size, num_samples, recognition", + "input_size, num_samples, recognition, detection", [ - [[512, 512], 2, False], - [[32, 128], 5, True], + [[512, 512], 2, False, False], # Actual set has 1268 training samples and 472 test samples + [[32, 128], 5, True, False], # recognition + [[512, 512], 2, False, True], # detection ], ) -def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, mock_wildreceipt_dataset): +def test_wildreceipt_dataset(input_size, num_samples, rotate, recognition, detection, mock_wildreceipt_dataset): ds = datasets.WILDRECEIPT( *mock_wildreceipt_dataset, train=True, img_transforms=Resize(input_size), use_polygons=rotate, recognition_task=recognition, + detection_task=detection, ) assert len(ds) == num_samples assert repr(ds) == f"WILDRECEIPT(train={True})" if recognition: _validate_dataset_recognition_part(ds, input_size) + elif detection: + _validate_dataset_detection_part(ds, input_size, is_polygons=rotate) else: 
_validate_dataset(ds, input_size, is_polygons=rotate) + with pytest.raises(ValueError): + datasets.WILDRECEIPT(*mock_wildreceipt_dataset, train=True, recognition_task=True, detection_task=True) + # NOTE: following datasets are only for recognition task From da8888008070954c00899940b1e8a6089fc107c1 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 1 Oct 2024 10:42:24 +0200 Subject: [PATCH 11/18] [Bug] Fix eval scripts + possible overflow in Resize (#1715) --- api/app/vision.py | 2 +- doctr/transforms/modules/pytorch.py | 13 +++---- doctr/transforms/modules/tensorflow.py | 26 ++++++++------ .../classification/latency_tensorflow.py | 2 +- .../train_tensorflow_character.py | 2 +- .../train_tensorflow_orientation.py | 2 +- references/detection/evaluate_tensorflow.py | 2 +- references/detection/latency_tensorflow.py | 2 +- references/detection/train_tensorflow.py | 2 +- references/recognition/evaluate_tensorflow.py | 2 +- references/recognition/latency_tensorflow.py | 2 +- references/recognition/train_tensorflow.py | 2 +- scripts/analyze.py | 2 +- scripts/detect_text.py | 2 +- scripts/evaluate.py | 35 ++++++++++++++++--- scripts/evaluate_kie.py | 35 ++++++++++++++++--- tests/pytorch/test_transforms_pt.py | 16 +++++++++ tests/tensorflow/test_transforms_tf.py | 16 +++++++++ 18 files changed, 128 insertions(+), 37 deletions(-) diff --git a/api/app/vision.py b/api/app/vision.py index 005c8d1548..144b5e4c3b 100644 --- a/api/app/vision.py +++ b/api/app/vision.py @@ -6,7 +6,7 @@ import tensorflow as tf -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index f893afc2f7..639b27e2cf 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -74,16 +74,18 @@ def forward( if self.symmetric_pad: half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) + # Pad image img = pad(img, _pad) # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) if target is not None: + if self.symmetric_pad: + offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] + if self.preserve_aspect_ratio: # Get absolute coords if target.shape[1:] == (4,): if isinstance(self.size, (tuple, list)) and self.symmetric_pad: - if np.max(target) <= 1: - offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1] target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2] else: @@ -91,16 +93,15 @@ def forward( target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2] elif target.shape[1:] == (4, 2): if isinstance(self.size, (tuple, list)) and self.symmetric_pad: - if np.max(target) <= 1: - offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1] target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2] else: target[..., 0] *= raw_shape[-1] / img.shape[-1] target[..., 1] *= raw_shape[-2] / img.shape[-2] else: - raise AssertionError - return img, target + raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)") + + return img, np.clip(target, 0, 1) return img diff --git 
a/doctr/transforms/modules/tensorflow.py b/doctr/transforms/modules/tensorflow.py index b3f7bcfd8a..4b00a9359f 100644 --- a/doctr/transforms/modules/tensorflow.py +++ b/doctr/transforms/modules/tensorflow.py @@ -107,29 +107,34 @@ def __call__( target: Optional[np.ndarray] = None, ) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]: input_dtype = img.dtype + self.output_size = ( + (self.output_size, self.output_size) if isinstance(self.output_size, int) else self.output_size + ) img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias) # It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio raw_shape = img.shape[:2] + if self.symmetric_pad: + half_pad = (int((self.output_size[0] - img.shape[0]) / 2), 0) if self.preserve_aspect_ratio: if isinstance(self.output_size, (tuple, list)): # In that case we need to pad because we want to enforce both width and height if not self.symmetric_pad: - offset = (0, 0) + half_pad = (0, 0) elif self.output_size[0] == img.shape[0]: - offset = (0, int((self.output_size[1] - img.shape[1]) / 2)) - else: - offset = (int((self.output_size[0] - img.shape[0]) / 2), 0) - img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size) + half_pad = (0, int((self.output_size[1] - img.shape[1]) / 2)) + # Pad image + img = tf.image.pad_to_bounding_box(img, *half_pad, *self.output_size) # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) if target is not None: + if self.symmetric_pad: + offset = half_pad[0] / img.shape[0], half_pad[1] / img.shape[1] + if self.preserve_aspect_ratio: # Get absolute coords if target.shape[1:] == (4,): if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad: - if np.max(target) <= 1: - offset = offset[0] / img.shape[0], offset[1] / img.shape[1] target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1] target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0] else: @@ -137,16 +142,15 @@ def __call__( target[:, [1, 3]] *= raw_shape[0] / img.shape[0] elif target.shape[1:] == (4, 2): if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad: - if np.max(target) <= 1: - offset = offset[0] / img.shape[0], offset[1] / img.shape[1] target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1] target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0] else: target[..., 0] *= raw_shape[1] / img.shape[1] target[..., 1] *= raw_shape[0] / img.shape[0] else: - raise AssertionError - return tf.cast(img, dtype=input_dtype), target + raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)") + + return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1) return tf.cast(img, dtype=input_dtype) diff --git a/references/classification/latency_tensorflow.py b/references/classification/latency_tensorflow.py index fc010df91a..6ccdefac18 100644 --- a/references/classification/latency_tensorflow.py +++ b/references/classification/latency_tensorflow.py @@ -20,7 +20,7 @@ def main(args): if args.gpu: - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) else: diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index b2d24f2dbf..89d0165d90 100644 --- 
a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -18,7 +18,7 @@ from doctr.models import login_to_hub, push_to_hf_hub -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index e063174944..a7d3b96943 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -18,7 +18,7 @@ from doctr.models import login_to_hub, push_to_hf_hub -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index 4eef9a40b7..76bd29b59a 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -17,7 +17,7 @@ from keras import mixed_precision from tqdm import tqdm -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/references/detection/latency_tensorflow.py b/references/detection/latency_tensorflow.py index e3e0d1d8af..39c0cd6e36 100644 --- a/references/detection/latency_tensorflow.py +++ b/references/detection/latency_tensorflow.py @@ -20,7 +20,7 @@ def main(args): if args.gpu: - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) else: diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index b9c14494ad..5e71909f3d 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -19,7 +19,7 @@ from doctr.models import login_to_hub, push_to_hf_hub -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index 4c9d125285..9fea4f02ed 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -14,7 +14,7 @@ from keras import mixed_precision from tqdm import tqdm -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/references/recognition/latency_tensorflow.py b/references/recognition/latency_tensorflow.py index 405cf56892..318ff03fcb 100644 --- a/references/recognition/latency_tensorflow.py +++ b/references/recognition/latency_tensorflow.py @@ -20,7 +20,7 @@ def main(args): if args.gpu: - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) else: diff --git 
a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index c76355a2f2..ca04cb1200 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -20,7 +20,7 @@ from doctr.models import login_to_hub, push_to_hf_hub -gpu_devices = tf.config.experimental.list_physical_devices("GPU") +gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/scripts/analyze.py b/scripts/analyze.py index 94415267a2..fdffa30e48 100644 --- a/scripts/analyze.py +++ b/scripts/analyze.py @@ -16,7 +16,7 @@ if is_tf_available(): import tensorflow as tf - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/scripts/detect_text.py b/scripts/detect_text.py index f65b6685df..e3ca08c7b0 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -20,7 +20,7 @@ if is_tf_available(): import tensorflow as tf - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) diff --git a/scripts/evaluate.py b/scripts/evaluate.py index bc9459b727..86dbc0e561 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -11,6 +11,7 @@ from tqdm import tqdm from doctr import datasets +from doctr import transforms as T from doctr.file_utils import is_tf_available from doctr.models import ocr_predictor from doctr.utils.geometry import extract_crops, extract_rcrops @@ -20,7 +21,7 @@ if is_tf_available(): import tensorflow as tf - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) else: @@ -35,12 +36,24 @@ def main(args): if not args.rotation: args.eval_straight = True + input_shape = (args.size, args.size) + + # We define a transformation function which does transform the annotation + # to the required format for the Resize transformation + def _transform(img, target): + boxes = target["boxes"] + transformed_img, transformed_boxes = T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + )(img, boxes) + return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]} + predictor = ocr_predictor( args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size, - preserve_aspect_ratio=False, + preserve_aspect_ratio=False, # we handle the transformation directly in the dataset so this is set to False + symmetric_pad=False, # we handle the transformation directly in the dataset so this is set to False assume_straight_pages=not args.rotation, ) @@ -48,11 +61,22 @@ def main(args): testset = datasets.OCRDataset( img_folder=args.img_folder, label_file=args.label_file, + sample_transforms=_transform, ) sets = [testset] else: - train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight) - val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight) + train_set = datasets.__dict__[args.dataset]( + train=True, + download=True, + use_polygons=not args.eval_straight, + sample_transforms=_transform, + ) + val_set = datasets.__dict__[args.dataset]( + 
train=False, + download=True, + use_polygons=not args.eval_straight, + sample_transforms=_transform, + ) sets = [train_set, val_set] reco_metric = TextMatch() @@ -190,6 +214,9 @@ def parse_args(): parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels") parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing") parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition") + parser.add_argument("--size", type=int, default=1024, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples") parser.add_argument( "--eval-straight", diff --git a/scripts/evaluate_kie.py b/scripts/evaluate_kie.py index b3d75d9beb..ca17332e2c 100644 --- a/scripts/evaluate_kie.py +++ b/scripts/evaluate_kie.py @@ -13,6 +13,7 @@ from tqdm import tqdm from doctr import datasets +from doctr import transforms as T from doctr.file_utils import is_tf_available from doctr.models import kie_predictor from doctr.utils.geometry import extract_crops, extract_rcrops @@ -22,7 +23,7 @@ if is_tf_available(): import tensorflow as tf - gpu_devices = tf.config.experimental.list_physical_devices("GPU") + gpu_devices = tf.config.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) else: @@ -37,12 +38,24 @@ def main(args): if not args.rotation: args.eval_straight = True + input_shape = (args.size, args.size) + + # We define a transformation function which does transform the annotation + # to the required format for the Resize transformation + def _transform(img, target): + boxes = target["boxes"] + transformed_img, transformed_boxes = T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + )(img, boxes) + return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]} + predictor = kie_predictor( args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size, - preserve_aspect_ratio=False, + preserve_aspect_ratio=False, # we handle the transformation directly in the dataset so this is set to False + symmetric_pad=False, # we handle the transformation directly in the dataset so this is set to False assume_straight_pages=not args.rotation, ) @@ -50,11 +63,22 @@ def main(args): testset = datasets.OCRDataset( img_folder=args.img_folder, label_file=args.label_file, + sample_transforms=_transform, ) sets = [testset] else: - train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight) - val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight) + train_set = datasets.__dict__[args.dataset]( + train=True, + download=True, + use_polygons=not args.eval_straight, + sample_transforms=_transform, + ) + val_set = datasets.__dict__[args.dataset]( + train=False, + download=True, + use_polygons=not args.eval_straight, + sample_transforms=_transform, + ) sets = [train_set, val_set] reco_metric = TextMatch() @@ -187,6 +211,9 @@ def parse_args(): parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels") parser.add_argument("--rotation", dest="rotation", action="store_true", 
help="run rotated OCR + postprocessing") parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition") + parser.add_argument("--size", type=int, default=1024, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples") parser.add_argument( "--eval-straight", diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 2567dd8486..3c11412556 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -66,6 +66,22 @@ def test_resize(): out = transfo(input_t) assert out.dtype == torch.float16 + # --- Test with target (bounding boxes) --- + + target_boxes = np.array([[0.1, 0.1, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]]) + output_size = (64, 64) + + transfo = Resize(output_size, preserve_aspect_ratio=True) + input_t = torch.ones((3, 32, 64), dtype=torch.float32) + out, new_target = transfo(input_t, target_boxes) + + assert out.shape[-2:] == output_size + assert new_target.shape == target_boxes.shape + assert np.all(new_target >= 0) and np.all(new_target <= 1) + + out = transfo(input_t) + assert out.shape[-2:] == output_size + @pytest.mark.parametrize( "rgb_min", diff --git a/tests/tensorflow/test_transforms_tf.py b/tests/tensorflow/test_transforms_tf.py index e53945f2e3..5fa87eab8a 100644 --- a/tests/tensorflow/test_transforms_tf.py +++ b/tests/tensorflow/test_transforms_tf.py @@ -48,6 +48,22 @@ def test_resize(): out = transfo(input_t) assert out.dtype == tf.float16 + # --- Test with target (bounding boxes) --- + + target_boxes = np.array([[0.1, 0.1, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]]) + output_size = (64, 64) + + transfo = T.Resize(output_size, preserve_aspect_ratio=True) + input_t = tf.cast(tf.fill([64, 32, 3], 1), dtype=tf.float32) + out, new_target = transfo(input_t, target_boxes) + + assert out.shape[:2] == output_size + assert new_target.shape == target_boxes.shape + assert np.all(new_target >= 0) and np.all(new_target <= 1) + + out = transfo(input_t) + assert out.shape[:2] == output_size + def test_compose(): output_size = (16, 16) From 90c3fff5d50855aacbd57327a069dc94cb745104 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 7 Oct 2024 11:20:55 +0200 Subject: [PATCH 12/18] [To keep] -- [build] tf upgrade by keeping keras v2 (#1542) --- .../source/using_doctr/using_model_export.rst | 2 +- doctr/file_utils.py | 19 +++++++++++++++++++ doctr/io/image/tensorflow.py | 2 +- .../classification/magc_resnet/tensorflow.py | 4 ++-- .../classification/mobilenet/tensorflow.py | 4 ++-- .../classification/predictor/tensorflow.py | 2 +- .../classification/resnet/tensorflow.py | 6 +++--- .../classification/textnet/tensorflow.py | 2 +- doctr/models/classification/vgg/tensorflow.py | 4 ++-- doctr/models/classification/vit/tensorflow.py | 2 +- .../differentiable_binarization/tensorflow.py | 4 ++-- doctr/models/detection/fast/tensorflow.py | 2 +- doctr/models/detection/linknet/tensorflow.py | 2 +- .../models/detection/predictor/tensorflow.py | 2 +- doctr/models/modules/layers/tensorflow.py | 2 +- .../models/modules/transformer/tensorflow.py | 4 +--- .../modules/vision_transformer/tensorflow.py | 2 +- doctr/models/recognition/crnn/tensorflow.py | 4 ++-- doctr/models/recognition/master/tensorflow.py | 2 +- 
doctr/models/recognition/parseq/tensorflow.py | 5 +---- doctr/models/recognition/sar/tensorflow.py | 2 +- doctr/models/recognition/vitstr/tensorflow.py | 2 +- doctr/models/utils/tensorflow.py | 6 +++--- doctr/transforms/modules/tensorflow.py | 1 - pyproject.toml | 10 ++++------ references/classification/README.md | 4 ++-- .../classification/latency_tensorflow.py | 8 ++++++-- .../train_tensorflow_character.py | 13 +++++++++++-- .../train_tensorflow_orientation.py | 13 +++++++++++-- references/detection/evaluate_tensorflow.py | 6 +++++- references/detection/latency_tensorflow.py | 9 +++++++-- references/detection/train_tensorflow.py | 13 +++++++++++-- references/recognition/README.md | 2 +- references/recognition/evaluate_tensorflow.py | 6 +++++- references/recognition/latency_tensorflow.py | 8 ++++++-- references/recognition/train_tensorflow.py | 15 ++++++++++++--- .../test_models_classification_tf.py | 10 ++++------ tests/tensorflow/test_models_detection_tf.py | 12 ++++++------ tests/tensorflow/test_models_factory.py | 6 +++--- .../tensorflow/test_models_recognition_tf.py | 5 ++--- tests/tensorflow/test_models_utils_tf.py | 2 +- 41 files changed, 147 insertions(+), 82 deletions(-) diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index 48f570f699..c62c36169b 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -31,7 +31,7 @@ Advantages: .. code:: python3 import tensorflow as tf - from keras import mixed_precision + from tensorflow.keras import mixed_precision mixed_precision.set_global_policy('mixed_float16') predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 68e9dfffac..fc1129b0c1 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -35,6 +35,20 @@ logging.info("Disabling PyTorch because USE_TF is set") _torch_available = False +# Compatibility fix to make sure tensorflow.keras stays at Keras 2 +if "TF_USE_LEGACY_KERAS" not in os.environ: + os.environ["TF_USE_LEGACY_KERAS"] = "1" + +elif os.environ["TF_USE_LEGACY_KERAS"] != "1": + raise ValueError( + "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. 
" + ) + + +def ensure_keras_v2() -> None: # pragma: no cover + if not os.environ.get("TF_USE_LEGACY_KERAS") == "1": + os.environ["TF_USE_LEGACY_KERAS"] = "1" + if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: _tf_available = importlib.util.find_spec("tensorflow") is not None @@ -65,6 +79,11 @@ _tf_available = False else: logging.info(f"TensorFlow version {_tf_version} available.") + ensure_keras_v2() + import tensorflow as tf + + # Enable eager execution - this is required for some models to work properly + tf.config.run_functions_eagerly(True) else: # pragma: no cover logging.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index 3b1f1ed0e2..28fb2fadd5 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -7,8 +7,8 @@ import numpy as np import tensorflow as tf -from keras.utils import img_to_array from PIL import Image +from tensorflow.keras.utils import img_to_array from doctr.utils.common_types import AbstractPath diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index 12f7c6beea..fc7678f661 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from keras import activations, layers -from keras.models import Sequential +from tensorflow.keras import activations, layers +from tensorflow.keras.models import Sequential from doctr.datasets import VOCABS diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index 6250abc666..ff57c221dc 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from keras import layers -from keras.models import Sequential +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential from ....datasets import VOCABS from ...utils import conv_sequence, load_pretrained_params diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index ba26e1db54..23efbf6579 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from keras import Model +from tensorflow.keras import Model from doctr.models.preprocessor import PreProcessor from doctr.utils.repr import NestedObject diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 3e78ae0ae2..364b03c3a2 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -7,9 +7,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import tensorflow as tf -from keras import layers -from keras.applications import ResNet50 -from keras.models import Sequential +from tensorflow.keras import layers +from tensorflow.keras.applications import ResNet50 +from tensorflow.keras.models import Sequential from doctr.datasets import VOCABS diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index 3d79b15f09..b0bb9a7205 100644 
--- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from keras import Sequential, layers +from tensorflow.keras import Sequential, layers from doctr.datasets import VOCABS diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index d9e7bb374b..9ecdabd040 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -6,8 +6,8 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from keras import layers -from keras.models import Sequential +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential from doctr.datasets import VOCABS diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 28ff2e244e..8531193939 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import tensorflow as tf -from keras import Sequential, layers +from tensorflow.keras import Sequential, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import EncoderBlock diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 7fdbd43ce0..45e522b872 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -10,8 +10,8 @@ import numpy as np import tensorflow as tf -from keras import Model, Sequential, layers, losses -from keras.applications import ResNet50 +from tensorflow.keras import Model, Sequential, layers, losses +from tensorflow.keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 80fc31fea3..91d6c8cc4d 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf -from keras import Model, Sequential, layers +from tensorflow.keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 683c49373a..df8233cf20 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf -from keras import Model, Sequential, layers, losses +from tensorflow.keras import Model, Sequential, layers, losses from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index a7ccd4a9ac..a3d5085847 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from keras import Model +from tensorflow.keras import Model from 
doctr.models.detection._utils import _remove_padding from doctr.models.preprocessor import PreProcessor diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py index b1019be778..68849fbf6e 100644 --- a/doctr/models/modules/layers/tensorflow.py +++ b/doctr/models/modules/layers/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from keras import layers +from tensorflow.keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py index eef4f3dbea..50c7cef04d 100644 --- a/doctr/models/modules/transformer/tensorflow.py +++ b/doctr/models/modules/transformer/tensorflow.py @@ -7,14 +7,12 @@ from typing import Any, Callable, Optional, Tuple import tensorflow as tf -from keras import layers +from tensorflow.keras import layers from doctr.utils.repr import NestedObject __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"] -tf.config.run_functions_eagerly(True) - class PositionalEncoding(layers.Layer, NestedObject): """Compute positional encoding""" diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py index a73aa4c706..8386172eb1 100644 --- a/doctr/models/modules/vision_transformer/tensorflow.py +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Tuple import tensorflow as tf -from keras import layers +from tensorflow.keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index d366bfc14b..fb5cb72dff 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from keras import layers -from keras.models import Model, Sequential +from tensorflow.keras import layers +from tensorflow.keras.models import Model, Sequential from doctr.datasets import VOCABS diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 5b8192dee6..42cd216b2c 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from keras import Model, layers +from tensorflow.keras import Model, layers from doctr.datasets import VOCABS from doctr.models.classification import magc_resnet31 diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index bca7806903..b0e21a50d6 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf -from keras import Model, layers +from tensorflow.keras import Model, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward @@ -167,7 +167,6 @@ def __init__( self.postprocessor = PARSeqPostProcessor(vocab=self.vocab) - @tf.function def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor: # Generates permutations of the target sequence. 
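The `@tf.function` decorators removed here (and the `tf.config.run_functions_eagerly` call dropped from the transformer module above) rely on the global eager-execution and Keras-2 setup added to `doctr/file_utils.py` earlier in this patch. As a rough illustration only, a downstream script could reproduce that setup as sketched below, assuming `tf-keras` is installed next to TensorFlow >= 2.15; the environment variable has to be set before TensorFlow is first imported, which is what `ensure_keras_v2()` enforces in the reference scripts:

```python
import os

# Keep tf.keras resolving to Keras 2 (tf-keras); must run before TensorFlow is imported,
# mirroring the guard added in doctr/file_utils.py.
os.environ.setdefault("TF_USE_LEGACY_KERAS", "1")

import tensorflow as tf
from tensorflow.keras import Model, layers  # the import style used throughout this patch

# Mirrors the global switch added in doctr/file_utils.py; per-method @tf.function
# decorators become inert once functions run eagerly, hence their removal below.
tf.config.run_functions_eagerly(True)

print(tf.keras.__version__)  # expected to report a 2.x version when tf-keras is installed
```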
# Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py @@ -214,7 +213,6 @@ def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor: ) return combined - @tf.function def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: # Generate source and target mask for the decoder attention. sz = permutation.shape[0] @@ -234,7 +232,6 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple target_mask = mask[1:, :-1] return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool) - @tf.function def decode( self, target: tf.Tensor, diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 0776414c7a..89e93ea51e 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from keras import Model, Sequential, layers +from tensorflow.keras import Model, Sequential, layers from doctr.datasets import VOCABS from doctr.utils.repr import NestedObject diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 985f49a470..6b38cf7548 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from keras import Model, layers +from tensorflow.keras import Model, layers from doctr.datasets import VOCABS diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 51a2bc69a5..6f7dc14ab3 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -8,7 +8,7 @@ import tensorflow as tf import tf2onnx -from keras import Model, layers +from tensorflow.keras import Model, layers from doctr.utils.data import download_from_url @@ -77,7 +77,7 @@ def conv_sequence( ) -> List[layers.Layer]: """Builds a convolutional-based layer sequence - >>> from keras import Sequential + >>> from tensorflow.keras import Sequential >>> from doctr.models import conv_sequence >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) @@ -113,7 +113,7 @@ def conv_sequence( class IntermediateLayerGetter(Model): """Implements an intermediate layer getter - >>> from keras.applications import ResNet50 + >>> from tensorflow.keras.applications import ResNet50 >>> from doctr.models import IntermediateLayerGetter >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) diff --git a/doctr/transforms/modules/tensorflow.py b/doctr/transforms/modules/tensorflow.py index 4b00a9359f..2f2fb25f9c 100644 --- a/doctr/transforms/modules/tensorflow.py +++ b/doctr/transforms/modules/tensorflow.py @@ -399,7 +399,6 @@ def __init__(self, kernel_shape: Union[int, Iterable[int]], std: Tuple[float, fl def extra_repr(self) -> str: return f"kernel_shape={self.kernel_shape}, std={self.std}" - @tf.function def __call__(self, img: tf.Tensor) -> tf.Tensor: return tf.squeeze( _gaussian_filter( diff --git a/pyproject.toml b/pyproject.toml index aa0e02f98e..9745f8a7c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,12 +52,10 @@ dependencies = [ [project.optional-dependencies] tf = [ - # cf. 
https://github.com/mindee/doctr/pull/1182 # cf. https://github.com/mindee/doctr/pull/1461 - "tensorflow>=2.11.0,<2.16.0", + "tensorflow>=2.15.0,<3.0.0", + "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 - # TODO: This is a temporary fix until we can upgrade to a newer version of tensorflow - "numpy>=1.16.0,<2.0.0", ] torch = [ "torch>=1.12.0,<3.0.0", @@ -98,9 +96,9 @@ docs = [ ] dev = [ # Tensorflow - # cf. https://github.com/mindee/doctr/pull/1182 # cf. https://github.com/mindee/doctr/pull/1461 - "tensorflow>=2.11.0,<2.16.0", + "tensorflow>=2.15.0,<3.0.0", + "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 # PyTorch "torch>=1.12.0,<3.0.0", diff --git a/references/classification/README.md b/references/classification/README.md index d0e5c3b83a..6646b0d8ca 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -60,12 +60,12 @@ Feel free to inspect the multiple script option to customize your training to yo Character classification: -```python +```shell python references/classification/train_tensorflow_character.py --help ``` Orientation classification: -```python +```shell python references/classification/train_tensorflow_orientation.py --help ``` diff --git a/references/classification/latency_tensorflow.py b/references/classification/latency_tensorflow.py index 6ccdefac18..9ed7e16036 100644 --- a/references/classification/latency_tensorflow.py +++ b/references/classification/latency_tensorflow.py @@ -9,12 +9,16 @@ import os import time -import numpy as np -import tensorflow as tf +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" +import numpy as np +import tensorflow as tf + from doctr.models import classification diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index 89d0165d90..d3b6e16a0c 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -5,6 +5,10 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -13,7 +17,7 @@ import numpy as np import tensorflow as tf -from keras import Model, mixed_precision, optimizers +from tensorflow.keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -82,6 +86,11 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder +@tf.function +def apply_grads(optimizer, grads, model): + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): # Iterate over the batches of the dataset pbar = tqdm(train_loader, position=1) @@ -94,7 +103,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) - optimizer.apply_gradients(zip(grads, model.trainable_weights)) + apply_grads(optimizer, grads, model) pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}") diff --git a/references/classification/train_tensorflow_orientation.py 
b/references/classification/train_tensorflow_orientation.py index a7d3b96943..00cfe98add 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -5,6 +5,10 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -13,7 +17,7 @@ import numpy as np import tensorflow as tf -from keras import Model, mixed_precision, optimizers +from tensorflow.keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -96,6 +100,11 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder +@tf.function +def apply_grads(optimizer, grads, model): + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): # Iterate over the batches of the dataset pbar = tqdm(train_loader, position=1) @@ -108,7 +117,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) - optimizer.apply_gradients(zip(grads, model.trainable_weights)) + apply_grads(optimizer, grads, model) pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}") diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index 76bd29b59a..c224e07a91 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -5,6 +5,10 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + from doctr.file_utils import CLASS_NAME os.environ["USE_TF"] = "1" @@ -14,7 +18,7 @@ from pathlib import Path import tensorflow as tf -from keras import mixed_precision +from tensorflow.keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.list_physical_devices("GPU") diff --git a/references/detection/latency_tensorflow.py b/references/detection/latency_tensorflow.py index 39c0cd6e36..17cdf784a5 100644 --- a/references/detection/latency_tensorflow.py +++ b/references/detection/latency_tensorflow.py @@ -9,12 +9,17 @@ import os import time -import numpy as np -import tensorflow as tf +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import numpy as np +import tensorflow as tf + from doctr.models import detection diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 5e71909f3d..0a535cd7cd 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -5,6 +5,10 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -14,7 +18,7 @@ import numpy as np import tensorflow as tf -from keras import Model, mixed_precision, optimizers +from tensorflow.keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -82,6 +86,11 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder +@tf.function +def apply_grads(optimizer, grads, model): + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): train_iter = 
iter(train_loader) # Iterate over the batches of the dataset @@ -94,7 +103,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) - optimizer.apply_gradients(zip(grads, model.trainable_weights)) + apply_grads(optimizer, grads, model) pbar.set_description(f"Training loss: {train_loss.numpy():.6}") diff --git a/references/recognition/README.md b/references/recognition/README.md index b82a0d99b5..5823030120 100644 --- a/references/recognition/README.md +++ b/references/recognition/README.md @@ -81,7 +81,7 @@ When typing your labels, be aware that the VOCAB doesn't handle spaces. Also mak Feel free to inspect the multiple script option to customize your training to your own needs! -```python +```shell python references/recognition/train_pytorch.py --help ``` diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index 9fea4f02ed..dc034d333f 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -5,13 +5,17 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import time import tensorflow as tf -from keras import mixed_precision +from tensorflow.keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.list_physical_devices("GPU") diff --git a/references/recognition/latency_tensorflow.py b/references/recognition/latency_tensorflow.py index 318ff03fcb..26bc2d6bc1 100644 --- a/references/recognition/latency_tensorflow.py +++ b/references/recognition/latency_tensorflow.py @@ -9,12 +9,16 @@ import os import time -import numpy as np -import tensorflow as tf +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" +import numpy as np +import tensorflow as tf + from doctr.models import recognition diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index ca04cb1200..c12752a3e1 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -5,6 +5,10 @@ import os +from doctr.file_utils import ensure_keras_v2 + +ensure_keras_v2() + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -15,7 +19,7 @@ import numpy as np import tensorflow as tf -from keras import Model, mixed_precision, optimizers +from tensorflow.keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -83,6 +87,11 @@ def record_lr( return lr_recorder[: len(loss_recorder)], loss_recorder +@tf.function +def apply_grads(optimizer, grads, model): + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): train_iter = iter(train_loader) # Iterate over the batches of the dataset @@ -95,7 +104,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) - optimizer.apply_gradients(zip(grads, model.trainable_weights)) + apply_grads(optimizer, grads, model) pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}") @@ -254,7 +263,7 @@ def main(args): T.RandomSaturation(0.3), 
T.RandomContrast(0.3), T.RandomBrightness(0.3), - T.RandomApply(T.RandomShadow(), 0.4), + # T.RandomApply(T.RandomShadow(), 0.4), # NOTE: RandomShadow is broken atm T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3), ]), diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 11f4ea4114..89181aace0 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -2,7 +2,6 @@ import tempfile import cv2 -import keras import numpy as np import onnxruntime import psutil @@ -38,7 +37,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): # Model batch_size = 2 - keras.backend.clear_session() + tf.keras.backend.clear_session() model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) # Forward out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) @@ -47,7 +46,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): assert out.dtype == tf.float32 assert out.numpy().shape == (batch_size, *output_size) # Check that you can load pretrained up to the classification layer with differing number of classes to fine-tune - keras.backend.clear_session() + tf.keras.backend.clear_session() assert classification.__dict__[arch_name]( pretrained=True, include_top=True, input_shape=input_shape, num_classes=10 ) @@ -63,7 +62,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): def test_classification_models(arch_name, input_shape): batch_size = 8 reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, keras.Model) + assert isinstance(reco_model, tf.keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) out = reco_model(input_tensor) @@ -232,7 +231,7 @@ def test_page_orientation_model(mock_payslip): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - keras.backend.clear_session() + tf.keras.backend.clear_session() if "orientation" in arch_name: model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) else: @@ -252,7 +251,6 @@ def test_models_onnx_export(arch_name, input_shape, output_size): model_path, output = export_model_to_onnx( model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input ) - assert os.path.exists(model_path) # Inference ort_session = onnxruntime.InferenceSession( diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py index ba5f50542b..7dbb090bf2 100644 --- a/tests/tensorflow/test_models_detection_tf.py +++ b/tests/tensorflow/test_models_detection_tf.py @@ -2,7 +2,6 @@ import os import tempfile -import keras import numpy as np import onnxruntime import psutil @@ -38,13 +37,13 @@ ) def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): batch_size = 2 - keras.backend.clear_session() + tf.keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, input_shape=input_shape)) train_mode = False # Reparameterized model is not trainable else: model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(model, keras.Model) + assert isinstance(model, tf.keras.Model) input_tensor = 
tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = [ {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, @@ -153,7 +152,7 @@ def test_rotated_detectionpredictor(mock_pdf): ) def test_detection_zoo(arch_name): # Model - keras.backend.clear_session() + tf.keras.backend.clear_session() predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) # object check assert isinstance(predictor, DetectionPredictor) @@ -178,7 +177,7 @@ def test_fast_reparameterization(): base_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(base_model_params, 13535296) # base model params base_out = base_model(dummy_input, training=False)["logits"] - keras.backend.clear_session() + tf.keras.backend.clear_session() rep_model = reparameterize(base_model) rep_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(rep_model_params, 8520256) # reparameterized model params @@ -242,7 +241,7 @@ def test_dilate(): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - keras.backend.clear_session() + tf.keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True, input_shape=input_shape)) else: @@ -257,6 +256,7 @@ def test_models_onnx_export(arch_name, input_shape, output_size): model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input ) assert os.path.exists(model_path) + # Inference ort_session = onnxruntime.InferenceSession( os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index 0860d8612c..a4483800c9 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -2,8 +2,8 @@ import os import tempfile -import keras import pytest +import tensorflow as tf from doctr import models from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub @@ -50,7 +50,7 @@ def test_push_to_hf_hub(): ) def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): with tempfile.TemporaryDirectory() as tmp_dir: - keras.backend.clear_session() + tf.keras.backend.clear_session() model = models.__dict__[task_name].__dict__[arch_name](pretrained=True) _save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) @@ -65,6 +65,6 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): assert all(key in model.cfg.keys() for key in tmp_config.keys()) # test from hub - keras.backend.clear_session() + tf.keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index 7da1cb534a..162c446d35 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -2,7 +2,6 @@ import shutil import tempfile -import keras import numpy as np import onnxruntime import psutil @@ -41,7 +40,7 @@ def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): batch_size = 4 reco_model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, keras.Model) + assert isinstance(reco_model, tf.keras.Model) input_tensor = 
tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = ["i", "am", "a", "jedi"] @@ -195,7 +194,7 @@ def test_recognition_zoo_error(): def test_models_onnx_export(arch_name, input_shape): # Model batch_size = 2 - keras.backend.clear_session() + tf.keras.backend.clear_session() model = recognition.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) # SAR, MASTER, ViTSTR export currently only available with constant batch size if arch_name in ["sar_resnet31", "master", "vitstr_small", "parseq"]: diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index b57b41b14b..4783a09b40 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -2,7 +2,7 @@ import pytest import tensorflow as tf -from keras.applications import ResNet50 +from tensorflow.keras.applications import ResNet50 from doctr.models.classification import mobilenet_v3_small from doctr.models.utils import ( From 59f1c30cd271e81d1fa21c6598c2876b048deace Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 8 Oct 2024 11:03:02 +0200 Subject: [PATCH 13/18] [demo] Automate doctr demo update via CI job (#1742) --- .github/workflows/demo.yml | 54 +++++++++++++++++++++++++++++++++++--- demo/README.md | 47 +++++++++++++++++++++++++++++++++ demo/packages.txt | 1 + 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 demo/README.md create mode 100644 demo/packages.txt diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml index 0fc7f203ff..ad5a1045b3 100644 --- a/.github/workflows/demo.yml +++ b/.github/workflows/demo.yml @@ -1,10 +1,23 @@ name: demo on: - push: - branches: main + # Run 'test-demo' on every pull request to the main branch pull_request: - branches: main + branches: [main] + + # Run 'test-demo' on every push to the main branch or both jobs when a new version tag is pushed + push: + branches: + - main + tags: + - 'v*' + + # Run 'sync-to-hub' on a scheduled cron job + schedule: + - cron: '0 2 10 * *' # At 02:00 on day-of-month 10 (every month) + + # Allow manual triggering of the workflow + workflow_dispatch: jobs: test-demo: @@ -69,3 +82,38 @@ jobs: screen -dm streamlit run demo/app.py sleep 10 curl http://localhost:8501/docs + + # This job only runs when a new version tag is pushed or during the cron job or when manually triggered + sync-to-hub: + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' + needs: test-demo + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.9"] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Install huggingface_hub + run: pip install huggingface-hub + - name: Upload folder to Hugging Face + # Only keep the requirements.txt file for the demo (PyTorch) + run: | + mv demo/pt-requirements.txt demo/requirements.txt + rm demo/tf-requirements.txt + + python -c " + from huggingface_hub import HfApi + api = HfApi(token='${{ secrets.HF_TOKEN }}') + repo_id = 'mindee/doctr' + api.upload_folder(repo_id=repo_id, repo_type='space', folder_path='demo/') + api.restart_space(repo_id=repo_id, factory_reboot=True) + " diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 0000000000..ec653d3068 --- /dev/null +++ b/demo/README.md @@ -0,0 
+1,47 @@ +--- +title: docTR +emoji: 📑 +colorFrom: purple +colorTo: pink +sdk: streamlit +sdk_version: 1.39.0 +app_file: app.py +pinned: false +license: apache-2.0 +--- + +## Configuration + +`title`: _string_ +Display title for the Space + +`emoji`: _string_ +Space emoji (emoji-only character allowed) + +`colorFrom`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`colorTo`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`sdk`: _string_ +Can be either `gradio` or `streamlit` + +`sdk_version` : _string_ +Only applicable for `streamlit` SDK. +See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. + +`app_file`: _string_ +Path to your main application file (which contains either `gradio` or `streamlit` Python code). +Path is relative to the root of the repository. + +`pinned`: _boolean_ +Whether the Space stays on top of your list. + +## Run the demo locally + +```bash +cd demo +pip install -r pt-requirements.txt +streamlit run app.py +``` diff --git a/demo/packages.txt b/demo/packages.txt new file mode 100644 index 0000000000..d0f1245c6f --- /dev/null +++ b/demo/packages.txt @@ -0,0 +1 @@ +python3-opencv From 2f9f50e5a7ec5ea4045fd4baff5aa93c9d2ed906 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Wed, 9 Oct 2024 17:15:13 +0200 Subject: [PATCH 14/18] [TF] Move model building & unify train scripts (#1744) --- .github/workflows/references.yml | 12 ++++++------ .../classification/magc_resnet/tensorflow.py | 6 ++++-- .../classification/mobilenet/tensorflow.py | 4 +++- doctr/models/classification/resnet/tensorflow.py | 5 ++++- .../models/classification/textnet/tensorflow.py | 4 +++- doctr/models/classification/vgg/tensorflow.py | 4 +++- doctr/models/classification/vit/tensorflow.py | 4 +++- .../differentiable_binarization/tensorflow.py | 11 ++++++++++- doctr/models/detection/fast/tensorflow.py | 7 +++---- doctr/models/detection/linknet/tensorflow.py | 14 +++++++++++--- doctr/models/factory/hub.py | 6 ------ doctr/models/preprocessor/tensorflow.py | 2 +- doctr/models/recognition/crnn/tensorflow.py | 3 ++- doctr/models/recognition/master/tensorflow.py | 4 +++- doctr/models/recognition/parseq/tensorflow.py | 4 +++- doctr/models/recognition/sar/tensorflow.py | 3 ++- doctr/models/recognition/vitstr/tensorflow.py | 4 +++- doctr/models/utils/tensorflow.py | 16 +++++++++++----- references/classification/README.md | 4 ++-- .../classification/train_pytorch_orientation.py | 6 +++--- .../classification/train_tensorflow_character.py | 2 -- .../train_tensorflow_orientation.py | 10 ++++------ references/detection/README.md | 4 ++-- references/detection/evaluate_tensorflow.py | 2 +- references/detection/train_pytorch.py | 4 ++-- references/detection/train_tensorflow.py | 12 ++---------- references/detection/utils.py | 8 -------- references/recognition/README.md | 2 +- references/recognition/evaluate_tensorflow.py | 2 +- references/recognition/train_tensorflow.py | 2 -- 30 files changed, 93 insertions(+), 78 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index 56856ba1d3..f79784244a 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -114,16 +114,16 @@ jobs: unzip toy_recogition_set-036a4d80.zip -d reco_set - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) (document orientation) - run: python references/classification/train_tensorflow_orientation.py ./det_set ./det_set resnet18 page -b 
2 --epochs 1 + run: python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) (document orientation) - run: python references/classification/train_pytorch_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1 + run: python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) (crop orientation) - run: python references/classification/train_tensorflow_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1 + run: python references/classification/train_tensorflow_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) (crop orientation) - run: python references/classification/train_pytorch_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1 + run: python references/classification/train_pytorch_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1 train-text-recognition: runs-on: ${{ matrix.os }} @@ -318,10 +318,10 @@ jobs: unzip toy_detection_set-bbbb4243.zip -d det_set - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) - run: python references/detection/train_tensorflow.py --train_path ./det_set --val_path ./det_set linknet_resnet18 -b 2 --epochs 1 + run: python references/detection/train_tensorflow.py linknet_resnet18 --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) - run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1 + run: python references/detection/train_pytorch.py db_mobilenet_v3_large --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 evaluate-text-detection: runs-on: ${{ matrix.os }} diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index fc7678f661..d920ca44a4 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params from ..resnet.tensorflow import ResNet __all__ = ["magc_resnet31"] @@ -115,7 +115,7 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: # Context modeling: B, H, W, C -> B, 1, 1, C context = self.context_modeling(inputs) # Transform: B, 1, 1, C -> B, 1, 1, C - transformed = self.transform(context) + transformed = self.transform(context, **kwargs) return inputs + transformed @@ -152,6 +152,8 @@ def _magc_resnet( cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index ff57c221dc..ae3535d947 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential from ....datasets import VOCABS -from ...utils import conv_sequence, 
load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = [ "MobileNetV3", @@ -295,6 +295,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 364b03c3a2..662a43c3a0 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"] @@ -210,6 +210,8 @@ def _resnet( model = ResNet( num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => @@ -358,6 +360,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: ) model.cfg = _cfg + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index b0bb9a7205..e5e6105a7e 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...modules.layers.tensorflow import FASTConvLayer -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["textnet_tiny", "textnet_small", "textnet_base"] @@ -111,6 +111,8 @@ def _textnet( # Build the model model = TextNet(cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 9ecdabd040..c42e369bcd 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -11,7 +11,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["VGG", "vgg16_bn_r"] @@ -81,6 +81,8 @@ def _vgg( # Build the model model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 8531193939..386065bca6 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -14,7 +14,7 @@ from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding from doctr.utils.repr import NestedObject -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params __all__ = ["vit_s", "vit_b"] @@ -121,6 +121,8 @@ def _vit( # Build the model model = 
VisionTransformer(cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 45e522b872..b0ca1f08e5 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -14,7 +14,13 @@ from tensorflow.keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_float32, + _build_model, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from ...classification import mobilenet_v3_large @@ -304,6 +310,8 @@ def _db_resnet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning @@ -347,6 +355,7 @@ def _db_mobilenet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 91d6c8cc4d..b0043494ed 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -13,7 +13,7 @@ from tensorflow.keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params from doctr.utils.repr import NestedObject from ...classification import textnet_base, textnet_small, textnet_tiny @@ -333,6 +333,8 @@ def _fast( # Build the model model = FAST(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning @@ -342,9 +344,6 @@ def _fast( skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), ) - # Build the model for reparameterization to access the layers - _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) - return model diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index df8233cf20..9c991c6f4c 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -14,7 +14,13 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_float32, + _build_model, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from .base import LinkNetPostProcessor, _LinkNet @@ -79,10 +85,10 @@ def __init__( for in_chan, 
out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1]) ] - def call(self, x: List[tf.Tensor]) -> tf.Tensor: + def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor: out = 0 for decoder, fmap in zip(self.decoders, x[::-1]): - out = decoder(out + fmap) + out = decoder(out + fmap, **kwargs) return out def extra_repr(self) -> str: @@ -274,6 +280,8 @@ def _linknet( # Build the model model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index b5844dd30b..dd9fc5d776 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -27,8 +27,6 @@ if is_torch_available(): import torch -elif is_tf_available(): - import tensorflow as tf __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] @@ -76,8 +74,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task torch.save(model.state_dict(), weights_path) elif is_tf_available(): weights_path = save_directory / "tf_model.weights.h5" - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.save_weights(str(weights_path)) config_path = save_directory / "config.json" @@ -229,8 +225,6 @@ def from_hub(repo_id: str, **kwargs: Any): model.load_state_dict(state_dict) else: # tf weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs) - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.load_weights(weights) return model diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index 5a211004f3..85e06fca3e 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -41,7 +41,7 @@ def __init__( self.resize = Resize(output_size, **kwargs) # Perform the division by 255 at the same time self.normalize = Normalize(mean, std) - self._runs_on_cuda = tf.test.is_gpu_available() + self._runs_on_cuda = tf.config.list_physical_devices("GPU") != [] def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]: """Gather samples into batches for inference purposes diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index fb5cb72dff..9f74882673 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -245,6 +245,7 @@ def _crnn( # Build the model model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 42cd216b2c..e01c089012 
100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -13,7 +13,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -290,6 +290,8 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index b0e21a50d6..d8c54527be 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -16,7 +16,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -473,6 +473,8 @@ def _parseq( # Build the model model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 89e93ea51e..bcb0b207ef 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -13,7 +13,7 @@ from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -392,6 +392,7 @@ def _sar( # Build the model model = SAR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 6b38cf7548..9b121171f8 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -216,6 +216,8 @@ def _vitstr( # Build the model model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 6f7dc14ab3..c04a4b2893 100644 --- 
a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -17,6 +17,7 @@ __all__ = [ "load_pretrained_params", + "_build_model", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", @@ -34,6 +35,16 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor: return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x +def _build_model(model: Model): + """Build a model by calling it once with dummy input + + Args: + ---- + model: the model to be built + """ + model(tf.zeros((1, *model.cfg["input_shape"])), training=False) + + def load_pretrained_params( model: Model, url: Optional[str] = None, @@ -58,11 +69,6 @@ def load_pretrained_params( logging.warning("Invalid model URL, using default initialization.") else: archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) - - # Build the model - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) - # Load weights model.load_weights(archive_path, skip_mismatch=skip_mismatch) diff --git a/references/classification/README.md b/references/classification/README.md index 6646b0d8ca..885cc0b565 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -30,13 +30,13 @@ python references/classification/train_pytorch_character.py mobilenet_v3_large - You can start your training in TensorFlow: ```shell -python references/classification/train_tensorflow_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5 +python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` or PyTorch: ```shell -python references/classification/train_pytorch_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5 +python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` The type can be either `page` for document images or `crop` for word crops. 
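The `_build_model` helper added to `doctr/models/utils/tensorflow.py` above centralizes a pattern that the factories, the hub helpers, and the training scripts previously repeated by hand (`_ = model(tf.ones(...), training=False)`): subclassed Keras models create their variables lazily on the first forward pass, so they must be called once before `load_weights` or `save_weights` can succeed. With the factories now building the model themselves, resuming from a checkpoint reduces to the sketch below (hypothetical checkpoint path, TensorFlow backend assumed):

```python
from doctr.models import detection

# The factory already ran one dummy forward pass via _build_model, so the
# variables exist and weights can be restored without any manual build step.
model = detection.db_resnet50(pretrained=False)
model.load_weights("path/to/checkpoint.weights.h5")  # hypothetical path
```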
diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py index 46c77d4c38..8324f0aa37 100644 --- a/references/classification/train_pytorch_orientation.py +++ b/references/classification/train_pytorch_orientation.py @@ -375,10 +375,10 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="classification model to train") - parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index d3b6e16a0c..0b1b648d93 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -185,8 +185,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 00cfe98add..297a5674f4 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -196,8 +196,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, *input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ @@ -340,7 +338,7 @@ def main(args): if args.export_onnx: print("Exporting model to ONNX...") - if args.arch == "vit_b": + if args.arch in ["vit_s", "vit_b"]: # fixed batch size for vit dummy_input = [tf.TensorSpec([1, *(input_size), 3], tf.float32, name="input")] else: @@ -358,10 +356,10 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="classification model to train") - parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--train_path", type=str, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") 
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/detection/README.md b/references/detection/README.md index 7a07b4cb6b..35d1481877 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -16,13 +16,13 @@ pip install -r references/requirements.txt You can start your training in TensorFlow: ```shell -python references/detection/train_tensorflow.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 +python references/detection/train_tensorflow.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` or PyTorch: ```shell -python references/detection/train_pytorch.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 --device 0 +python references/detection/train_pytorch.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` ## Data format diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index c224e07a91..a2c5bbe49c 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -40,7 +40,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): for images, targets in tqdm(val_loader): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] - out = model(images, targets, training=False, return_preds=True) + out = model(images, target=targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] for target, loc_pred in zip(targets, loc_preds): diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 0c30925146..091d257898 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -427,9 +427,9 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 0a535cd7cd..f054879e8f 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -31,7 +31,7 @@ from doctr.datasets import DataLoader, DetectionDataset from doctr.models import detection from doctr.utils.metrics import LocalizationConfusion -from utils import EarlyStopper, load_backbone, plot_recorder, plot_samples +from utils import EarlyStopper, plot_recorder, plot_samples def record_lr( @@ -193,15 +193,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) 
model.load_weights(args.resume) - if isinstance(args.pretrained_backbone, str): - print("Loading backbone weights.") - model = load_backbone(model, args.pretrained_backbone) - print("Done.") - # Metrics val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight) @@ -411,7 +404,7 @@ def parse_args(): parser.add_argument("arch", type=str, help="text-detection model to train") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") - parser.add_argument("--val_path", type=str, help="path to validation data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") @@ -421,7 +414,6 @@ def parse_args(): parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") - parser.add_argument("--pretrained-backbone", type=str, default=None, help="Path to your backbone weights") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") parser.add_argument( "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" diff --git a/references/detection/utils.py b/references/detection/utils.py index 7983ee4d51..1a84f2340d 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -3,7 +3,6 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-import pickle from typing import Dict, List import cv2 @@ -86,13 +85,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N plt.show(**kwargs) -def load_backbone(model, weights_path): - pretrained_backbone_weights = pickle.load(open(weights_path, "rb")) - model.feat_extractor.set_weights(pretrained_backbone_weights[0]) - model.fpn.set_weights(pretrained_backbone_weights[1]) - return model - - class EarlyStopper: def __init__(self, patience: int = 5, min_delta: float = 0.01): self.patience = patience diff --git a/references/recognition/README.md b/references/recognition/README.md index 5823030120..9087cbc210 100644 --- a/references/recognition/README.md +++ b/references/recognition/README.md @@ -22,7 +22,7 @@ python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path pat or PyTorch: ```shell -python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --device 0 +python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` ### Multi-GPU support (PyTorch only - Experimental) diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index dc034d333f..b6ca50b516 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): for images, targets in tqdm(val_iter): try: images = batch_transforms(images) - out = model(images, targets, return_preds=True, training=False) + out = model(images, target=targets, return_preds=True, training=False) # Compute metric if len(out["preds"]): words, _ = zip(*out["preds"]) diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index c12752a3e1..348f3a3869 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -193,8 +193,6 @@ def main(args): ) # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, 4 * args.input_size, 3)), training=False) model.load_weights(args.resume) # Metrics From 0ca8249b630dc0f67bf89c477d45d665f84135c0 Mon Sep 17 00:00:00 2001 From: T2K-Felix <125863421+felixT2K@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:12:22 +0200 Subject: [PATCH 15/18] [demo] Add missing viz dep for demo (#1751) --- demo/pt-requirements.txt | 2 +- demo/tf-requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/demo/pt-requirements.txt b/demo/pt-requirements.txt index 639f79c607..5b098487ac 100644 --- a/demo/pt-requirements.txt +++ b/demo/pt-requirements.txt @@ -1,2 +1,2 @@ --e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch] +-e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch,viz] streamlit>=1.0.0 diff --git a/demo/tf-requirements.txt b/demo/tf-requirements.txt index f16f3aa066..2fd4b08a64 100644 --- a/demo/tf-requirements.txt +++ b/demo/tf-requirements.txt @@ -1,2 +1,2 @@ --e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf] +-e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf,viz] streamlit>=1.0.0 From 7f21e76175cdb47c430b9c5a031a09686be94fb6 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Thu, 10 Oct 2024 19:02:18 +0200 Subject: [PATCH 16/18] [Build] update minor version & update torch to >= 2.0 (#1747) --- 
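Related to the evaluate-script hunks above: the extra model inputs are now passed by keyword (`target=...`) rather than positionally, and the manual dummy build before `model.load_weights(args.resume)` is gone because the factories build the graph themselves. A rough sketch of the resulting evaluation call is shown below, assuming the TensorFlow backend; the image batch and labels are placeholders standing in for a real validation batch.

```python
# Rough sketch of the updated recognition evaluation call (TensorFlow backend).
# The image batch and labels are placeholders, not real validation data.
import tensorflow as tf

from doctr.models import recognition

model = recognition.crnn_vgg16_bn(pretrained=True)

images = tf.zeros((2, 32, 128, 3))  # normalized crops, shaped like the model input
targets = ["hello", "world"]        # ground-truth transcriptions for the two crops

# `target` is now passed by keyword, mirroring references/recognition/evaluate_tensorflow.py
out = model(images, target=targets, return_preds=True, training=False)
if len(out["preds"]):
    words, confidences = zip(*out["preds"])
```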
.conda/meta.yaml | 2 +- README.md | 2 +- docs/source/getting_started/installing.rst | 2 +- pyproject.toml | 8 ++++---- setup.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 7feb3a1bf9..f377481f16 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -1,7 +1,7 @@ {% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %} {% set project = pyproject.get('project') %} {% set urls = pyproject.get('project', {}).get('urls') %} -{% set version = environ.get('BUILD_VERSION', '0.9.1a0') %} +{% set version = environ.get('BUILD_VERSION', '0.10.0a0') %} package: name: {{ project.get('name') }} diff --git a/README.md b/README.md index 2fc92971ff..d57228fda6 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ pip install "python-doctr[torch,viz,html,contib]" For MacBooks with M1 chip, you will need some additional packages or specific versions: - TensorFlow 2: [metal plugin](https://developer.apple.com/metal/tensorflow-plugin/) -- PyTorch: [version >= 1.12.0](https://pytorch.org/get-started/locally/#start-locally) +- PyTorch: [version >= 2.0.0](https://pytorch.org/get-started/locally/#start-locally) ### Developer mode diff --git a/docs/source/getting_started/installing.rst b/docs/source/getting_started/installing.rst index 46d4177b30..e764e734a7 100644 --- a/docs/source/getting_started/installing.rst +++ b/docs/source/getting_started/installing.rst @@ -17,7 +17,7 @@ Whichever OS you are running, you will need to install at least TensorFlow or Py For MacBooks with M1 chip, you will need some additional packages or specific versions: * `TensorFlow 2 Metal Plugin `_ -* `PyTorch >= 1.12.0 `_ +* `PyTorch >= 2.0.0 `_ Via Python Package ================== diff --git a/pyproject.toml b/pyproject.toml index 9745f8a7c4..613eb512e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,8 +58,8 @@ tf = [ "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 ] torch = [ - "torch>=1.12.0,<3.0.0", - "torchvision>=0.13.0", + "torch>=2.0.0,<3.0.0", + "torchvision>=0.15.0", "onnx>=1.12.0,<3.0.0", ] html = [ @@ -101,8 +101,8 @@ dev = [ "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. 
https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 # PyTorch - "torch>=1.12.0,<3.0.0", - "torchvision>=0.13.0", + "torch>=2.0.0,<3.0.0", + "torchvision>=0.15.0", "onnx>=1.12.0,<3.0.0", # Extras "weasyprint>=55.0", diff --git a/setup.py b/setup.py index f45f3f157d..13fd4515e3 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ from setuptools import setup PKG_NAME = "python-doctr" -VERSION = os.getenv("BUILD_VERSION", "0.9.1a0") +VERSION = os.getenv("BUILD_VERSION", "0.10.0a0") if __name__ == "__main__": From 2ca3928642f6b289ba2adb95bfada138db62d3f1 Mon Sep 17 00:00:00 2001 From: T2K-Felix <125863421+felixT2K@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:09:17 +0200 Subject: [PATCH 17/18] [demo/docs] Update notebook docs & minor demo update / fix (#1755) --- demo/app.py | 24 ++++++++++++------------ demo/backend/pytorch.py | 4 +++- demo/backend/tensorflow.py | 4 +++- demo/packages.txt | 1 + doctr/io/elements.py | 4 ++-- doctr/models/utils/pytorch.py | 2 +- notebooks/README.rst | 27 +++++++++++++++------------ 7 files changed, 37 insertions(+), 29 deletions(-) diff --git a/demo/app.py b/demo/app.py index 60adba0fb8..d43c8eff91 100644 --- a/demo/app.py +++ b/demo/app.py @@ -71,15 +71,14 @@ def main(det_archs, reco_archs): # Only straight pages or possible rotation st.sidebar.title("Parameters") assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True) - st.sidebar.write("\n") # Disable page orientation detection disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False) - st.sidebar.write("\n") # Disable crop orientation detection disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False) - st.sidebar.write("\n") # Straighten pages straighten_pages = st.sidebar.checkbox("Straighten pages", value=False) + # Export as straight boxes + export_straight_boxes = st.sidebar.checkbox("Export as straight boxes", value=False) st.sidebar.write("\n") # Binarization threshold bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1) @@ -95,15 +94,16 @@ def main(det_archs, reco_archs): else: with st.spinner("Loading model..."): predictor = load_predictor( - det_arch, - reco_arch, - assume_straight_pages, - straighten_pages, - disable_page_orientation, - disable_crop_orientation, - bin_thresh, - box_thresh, - forward_device, + det_arch=det_arch, + reco_arch=reco_arch, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + export_as_straight_boxes=export_straight_boxes, + disable_page_orientation=disable_page_orientation, + disable_crop_orientation=disable_crop_orientation, + bin_thresh=bin_thresh, + box_thresh=box_thresh, + device=forward_device, ) with st.spinner("Analyzing..."): diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index e3ced74d5f..548d696dde 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -37,6 +37,7 @@ def load_predictor( reco_arch: str, assume_straight_pages: bool, straighten_pages: bool, + export_as_straight_boxes: bool, disable_page_orientation: bool, disable_crop_orientation: bool, bin_thresh: float, @@ -51,6 +52,7 @@ def load_predictor( reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not + export_as_straight_boxes: whether to export boxes as straight or not disable_page_orientation: whether to disable page orientation or not 
disable_crop_orientation: whether to disable crop orientation or not bin_thresh: binarization threshold for the segmentation map @@ -67,7 +69,7 @@ def load_predictor( pretrained=True, assume_straight_pages=assume_straight_pages, straighten_pages=straighten_pages, - export_as_straight_boxes=straighten_pages, + export_as_straight_boxes=export_as_straight_boxes, detect_orientation=not assume_straight_pages, disable_page_orientation=disable_page_orientation, disable_crop_orientation=disable_crop_orientation, diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index 6ca9614159..9fecfce3bc 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -36,6 +36,7 @@ def load_predictor( reco_arch: str, assume_straight_pages: bool, straighten_pages: bool, + export_as_straight_boxes: bool, disable_page_orientation: bool, disable_crop_orientation: bool, bin_thresh: float, @@ -50,6 +51,7 @@ def load_predictor( reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not + export_as_straight_boxes: whether to export boxes as straight or not disable_page_orientation: whether to disable page orientation or not disable_crop_orientation: whether to disable crop orientation or not bin_thresh: binarization threshold for the segmentation map @@ -67,7 +69,7 @@ def load_predictor( pretrained=True, assume_straight_pages=assume_straight_pages, straighten_pages=straighten_pages, - export_as_straight_boxes=straighten_pages, + export_as_straight_boxes=export_as_straight_boxes, detect_orientation=not assume_straight_pages, disable_page_orientation=disable_page_orientation, disable_crop_orientation=disable_crop_orientation, diff --git a/demo/packages.txt b/demo/packages.txt index d0f1245c6f..c0e46c2d27 100644 --- a/demo/packages.txt +++ b/demo/packages.txt @@ -1 +1,2 @@ python3-opencv +fonts-freefont-ttf diff --git a/doctr/io/elements.py b/doctr/io/elements.py index b27ecb35eb..55e1d5e84a 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -168,7 +168,7 @@ def __init__( if geometry is None: # Check whether this is a rotated or straight box box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox - geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator] + geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[misc] super().__init__(words=words) self.geometry = geometry @@ -232,7 +232,7 @@ def __init__( box_resolution_fn = ( resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox ) - geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator] + geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore super().__init__(lines=lines, artefacts=artefacts) self.geometry = geometry diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 0401cdef6c..998ccb7cf1 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -157,7 +157,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T """ torch.onnx.export( model, - dummy_input, + dummy_input, # type: ignore[arg-type] f"{model_name}.onnx", input_names=["input"], output_names=["logits"], diff --git a/notebooks/README.rst b/notebooks/README.rst index 96f9e80edb..940e43f7b7 100644 --- a/notebooks/README.rst +++ b/notebooks/README.rst @@ -3,15 +3,18 @@ docTR Notebooks Here are some notebooks 
compiled for users to better leverage the library capabilities: -+--------------------------------------------------------------------------------------------------------+----------------------------------------------+---------------------------------------------------------------------------------------------------------------------+ -| Notebook | Description | Colab | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------+---------------------------------------------------------------------------------------------------------------------+ -| `[Quicktour] `_ | A presentation of the main features of docTR | .. image:: https://colab.research.google.com/assets/colab-badge.svg | -| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------+---------------------------------------------------------------------------------------------------------------------+ -| `[Export as PDF/A] `_ | Produce searchable PDFs from docTR results | .. image:: https://colab.research.google.com/assets/colab-badge.svg | -| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/export_as_pdfa.ipynb | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------+---------------------------------------------------------------------------------------------------------------------+ -| `[Artefact detection] `_ | Object detection for artefacts in documents | .. image:: https://colab.research.google.com/assets/colab-badge.svg | -| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/artefact_detection.ipynb | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------+---------------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ +| Notebook | Description | Colab | ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ +| `[Quicktour] `_ | A presentation of the main features of docTR | .. 
image:: https://colab.research.google.com/assets/colab-badge.svg | +| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb | ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ +| `[Export as PDF/A] `_ | Produce searchable PDFs from docTR results | .. image:: https://colab.research.google.com/assets/colab-badge.svg | +| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/export_as_pdfa.ipynb | ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ +| `[Using standalone predictors] `_ | Showcase how to use detection, recognition, and orientation predictors| .. image:: https://colab.research.google.com/assets/colab-badge.svg | +| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/using_standalone_predictors.ipynb | ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ +| `[Dealing with rotated documents] `_ | A presentation on how to handle documents containing rotations | .. 
image:: https://colab.research.google.com/assets/colab-badge.svg | +| | | :target: https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/dealing_with_rotations.ipynb | ++--------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------+ From d5dbc73eac6a9eaae487265702e3c837c570d512 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 21 Oct 2024 10:34:15 +0200 Subject: [PATCH 18/18] [Reconstitution] Improve reconstitution (#1750) --- doctr/io/elements.py | 12 +- doctr/utils/reconstitution.py | 216 +++++++++++++++------- tests/common/test_utils_reconstitution.py | 42 ++++- 3 files changed, 198 insertions(+), 72 deletions(-) diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 55e1d5e84a..324d70f0b4 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -310,6 +310,10 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** def synthesize(self, **kwargs) -> np.ndarray: """Synthesize the page from the predictions + Args: + ---- + **kwargs: keyword arguments passed to the `synthesize_page` method + Returns ------- synthesized page @@ -493,7 +497,7 @@ def synthesize(self, **kwargs) -> np.ndarray: Args: ---- - **kwargs: keyword arguments passed to the matplotlib.pyplot.show method + **kwargs: keyword arguments passed to the `synthesize_kie_page` method Returns: ------- @@ -603,11 +607,15 @@ def show(self, **kwargs) -> None: def synthesize(self, **kwargs) -> List[np.ndarray]: """Synthesize all pages from their predictions + Args: + ---- + **kwargs: keyword arguments passed to the `Page.synthesize` method + Returns ------- list of synthesized pages """ - return [page.synthesize() for page in self.pages] + return [page.synthesize(**kwargs) for page in self.pages] def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: """Export the document as XML (hOCR-format) diff --git a/doctr/utils/reconstitution.py b/doctr/utils/reconstitution.py index 82ae20cdd0..a229e9ddbc 100644 --- a/doctr/utils/reconstitution.py +++ b/doctr/utils/reconstitution.py @@ -2,6 +2,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
+import logging from typing import Any, Dict, Optional import numpy as np @@ -13,10 +14,109 @@ __all__ = ["synthesize_page", "synthesize_kie_page"] +# Global variable to avoid multiple warnings +ROTATION_WARNING = False + + +def _warn_rotation(entry: Dict[str, Any]) -> None: # pragma: no cover + global ROTATION_WARNING + if not ROTATION_WARNING and len(entry["geometry"]) == 4: + logging.warning("Polygons with larger rotations will lead to inaccurate rendering") + ROTATION_WARNING = True + + +def _synthesize( + response: Image.Image, + entry: Dict[str, Any], + w: int, + h: int, + draw_proba: bool = False, + font_family: Optional[str] = None, + smoothing_factor: float = 0.75, + min_font_size: int = 6, + max_font_size: int = 50, +) -> Image.Image: + if len(entry["geometry"]) == 2: + (xmin, ymin), (xmax, ymax) = entry["geometry"] + polygon = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)] + else: + polygon = entry["geometry"] + + # Calculate the bounding box of the word + x_coords, y_coords = zip(*polygon) + xmin, ymin, xmax, ymax = ( + int(round(w * min(x_coords))), + int(round(h * min(y_coords))), + int(round(w * max(x_coords))), + int(round(h * max(y_coords))), + ) + word_width = xmax - xmin + word_height = ymax - ymin + + # If lines are provided instead of words, concatenate the word entries + if "words" in entry: + word_text = " ".join(word["value"] for word in entry["words"]) + else: + word_text = entry["value"] + # Find the optimal font size + try: + font_size = min(word_height, max_font_size) + font = get_font(font_family, font_size) + text_width, text_height = font.getbbox(word_text)[2:4] + + while (text_width > word_width or text_height > word_height) and font_size > min_font_size: + font_size = max(int(font_size * smoothing_factor), min_font_size) + font = get_font(font_family, font_size) + text_width, text_height = font.getbbox(word_text)[2:4] + except ValueError: + font = get_font(font_family, min_font_size) + + # Create a mask for the word + mask = Image.new("L", (w, h), 0) + ImageDraw.Draw(mask).polygon([(int(round(w * x)), int(round(h * y))) for x, y in polygon], fill=255) + + # Draw the word text + d = ImageDraw.Draw(response) + try: + try: + d.text((xmin, ymin), word_text, font=font, fill=(0, 0, 0), anchor="lt") + except UnicodeEncodeError: + d.text((xmin, ymin), anyascii(word_text), font=font, fill=(0, 0, 0), anchor="lt") + # Catch generic exceptions to avoid crashing the whole rendering + except Exception: # pragma: no cover + logging.warning(f"Could not render word: {word_text}") + + if draw_proba: + confidence = ( + entry["confidence"] + if "confidence" in entry + else sum(w["confidence"] for w in entry["words"]) / len(entry["words"]) + ) + p = int(255 * confidence) + color = (255 - p, 0, p) # Red to blue gradient based on probability + d.rectangle([(xmin, ymin), (xmax, ymax)], outline=color, width=2) + + prob_font = get_font(font_family, 20) + prob_text = f"{confidence:.2f}" + prob_text_width, prob_text_height = prob_font.getbbox(prob_text)[2:4] + + # Position the probability slightly above the bounding box + prob_x_offset = (word_width - prob_text_width) // 2 + prob_y_offset = ymin - prob_text_height - 2 + prob_y_offset = max(0, prob_y_offset) + + d.text((xmin + prob_x_offset, prob_y_offset), prob_text, font=prob_font, fill=color, anchor="lt") + + return response + + def synthesize_page( page: Dict[str, Any], draw_proba: bool = False, font_family: Optional[str] = None, + smoothing_factor: float = 0.95, + min_font_size: int = 8, + max_font_size: int = 50, ) -> 
np.ndarray: """Draw a the content of the element page (OCR response) on a blank page. @@ -24,8 +124,10 @@ def synthesize_page( ---- page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 - font_size: size of the font, default font = 13 font_family: family of the font + smoothing_factor: factor to smooth the font size + min_font_size: minimum font size + max_font_size: maximum font size Returns: ------- @@ -33,41 +135,42 @@ def synthesize_page( """ # Draw template h, w = page["dimensions"] - response = 255 * np.ones((h, w, 3), dtype=np.int32) + response = Image.new("RGB", (w, h), color=(255, 255, 255)) - # Draw each word for block in page["blocks"]: - for line in block["lines"]: - for word in line["words"]: - # Get absolute word geometry - (xmin, ymin), (xmax, ymax) = word["geometry"] - xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) - ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) - - # White drawing context adapted to font size, 0.75 factor to convert pts --> pix - font = get_font(font_family, int(0.75 * (ymax - ymin))) - img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) - d = ImageDraw.Draw(img) - # Draw in black the value of the word - try: - d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) - except UnicodeEncodeError: - # When character cannot be encoded, use its anyascii version - d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0)) - - # Colorize if draw_proba - if draw_proba: - p = int(255 * word["confidence"]) - mask = np.where(np.array(img) == 0, 1, 0) - proba: np.ndarray = np.array([255 - p, 0, p]) - color = mask * proba[np.newaxis, np.newaxis, :] - white_mask = 255 * (1 - mask) - img = color + white_mask - - # Write to response page - response[ymin:ymax, xmin:xmax, :] = np.array(img) - - return response + # If lines are provided use these to get better rendering results + if len(block["lines"]) > 1: + for line in block["lines"]: + _warn_rotation(block) # pragma: no cover + response = _synthesize( + response=response, + entry=line, + w=w, + h=h, + draw_proba=draw_proba, + font_family=font_family, + smoothing_factor=smoothing_factor, + min_font_size=min_font_size, + max_font_size=max_font_size, + ) + # Otherwise, draw each word + else: + for line in block["lines"]: + _warn_rotation(block) # pragma: no cover + for word in line["words"]: + response = _synthesize( + response=response, + entry=word, + w=w, + h=h, + draw_proba=draw_proba, + font_family=font_family, + smoothing_factor=smoothing_factor, + min_font_size=min_font_size, + max_font_size=max_font_size, + ) + + return np.array(response, dtype=np.uint8) def synthesize_kie_page( @@ -81,8 +184,10 @@ def synthesize_kie_page( ---- page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. 
Blue: p=1, red: p=0 - font_size: size of the font, default font = 13 font_family: family of the font + smoothing_factor: factor to smooth the font size + min_font_size: minimum font size + max_font_size: maximum font size Returns: ------- @@ -90,37 +195,18 @@ def synthesize_kie_page( """ # Draw template h, w = page["dimensions"] - response = 255 * np.ones((h, w, 3), dtype=np.int32) + response = Image.new("RGB", (w, h), color=(255, 255, 255)) # Draw each word for predictions in page["predictions"].values(): for prediction in predictions: - # Get aboslute word geometry - (xmin, ymin), (xmax, ymax) = prediction["geometry"] - xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) - ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) - - # White drawing context adapted to font size, 0.75 factor to convert pts --> pix - font = get_font(font_family, int(0.75 * (ymax - ymin))) - img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) - d = ImageDraw.Draw(img) - # Draw in black the value of the word - try: - d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) - except UnicodeEncodeError: - # When character cannot be encoded, use its anyascii version - d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0)) - - # Colorize if draw_proba - if draw_proba: - p = int(255 * prediction["confidence"]) - mask = np.where(np.array(img) == 0, 1, 0) - proba: np.ndarray = np.array([255 - p, 0, p]) - color = mask * proba[np.newaxis, np.newaxis, :] - white_mask = 255 * (1 - mask) - img = color + white_mask - - # Write to response page - response[ymin:ymax, xmin:xmax, :] = np.array(img) - - return response + _warn_rotation(prediction) # pragma: no cover + response = _synthesize( + response=response, + entry=prediction, + w=w, + h=h, + draw_proba=draw_proba, + font_family=font_family, + ) + return np.array(response, dtype=np.uint8) diff --git a/tests/common/test_utils_reconstitution.py b/tests/common/test_utils_reconstitution.py index 3b70e67070..be98db89b2 100644 --- a/tests/common/test_utils_reconstitution.py +++ b/tests/common/test_utils_reconstitution.py @@ -1,12 +1,44 @@ import numpy as np -from test_io_elements import _mock_pages +from test_io_elements import _mock_kie_pages, _mock_pages from doctr.utils import reconstitution def test_synthesize_page(): pages = _mock_pages() - reconstitution.synthesize_page(pages[0].export(), draw_proba=False) - render = reconstitution.synthesize_page(pages[0].export(), draw_proba=True) - assert isinstance(render, np.ndarray) - assert render.shape == (*pages[0].dimensions, 3) + # Test without probability rendering + render_no_proba = reconstitution.synthesize_page(pages[0].export(), draw_proba=False) + assert isinstance(render_no_proba, np.ndarray) + assert render_no_proba.shape == (*pages[0].dimensions, 3) + + # Test with probability rendering + render_with_proba = reconstitution.synthesize_page(pages[0].export(), draw_proba=True) + assert isinstance(render_with_proba, np.ndarray) + assert render_with_proba.shape == (*pages[0].dimensions, 3) + + # Test with only one line + pages_one_line = pages[0].export() + pages_one_line["blocks"][0]["lines"] = [pages_one_line["blocks"][0]["lines"][0]] + render_one_line = reconstitution.synthesize_page(pages_one_line, draw_proba=True) + assert isinstance(render_one_line, np.ndarray) + assert render_one_line.shape == (*pages[0].dimensions, 3) + + # Test with polygons + pages_poly = pages[0].export() + pages_poly["blocks"][0]["lines"][0]["geometry"] = [(0, 0), (0, 1), (1, 1), (1, 0)] + 
render_poly = reconstitution.synthesize_page(pages_poly, draw_proba=True) + assert isinstance(render_poly, np.ndarray) + assert render_poly.shape == (*pages[0].dimensions, 3) + + +def test_synthesize_kie_page(): + pages = _mock_kie_pages() + # Test without probability rendering + render_no_proba = reconstitution.synthesize_kie_page(pages[0].export(), draw_proba=False) + assert isinstance(render_no_proba, np.ndarray) + assert render_no_proba.shape == (*pages[0].dimensions, 3) + + # Test with probability rendering + render_with_proba = reconstitution.synthesize_kie_page(pages[0].export(), draw_proba=True) + assert isinstance(render_with_proba, np.ndarray) + assert render_with_proba.shape == (*pages[0].dimensions, 3)
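To close the series, here is a short sketch of how the reworked reconstitution is reached from the public API, since `Document.synthesize` now forwards its keyword arguments down to `synthesize_page`. The entry points below are the usual doctr ones; the image path is a placeholder.

```python
# Sketch of driving the improved reconstitution from the public API.
# Assumes a working doctr install with matplotlib; the image path is a placeholder.
import matplotlib.pyplot as plt

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
doc = DocumentFile.from_images(["path/to/your/page.jpg"])  # placeholder path
result = predictor(doc)

# `Document.synthesize(**kwargs)` -> `Page.synthesize(**kwargs)` -> `synthesize_page(...)`,
# so the new rendering knobs are reachable from the top-level result object.
rendered = result.synthesize(draw_proba=True, min_font_size=8, max_font_size=50)

plt.imshow(rendered[0])
plt.axis("off")
plt.show()
```

With `draw_proba=True`, each box gets a red-to-blue outline and a confidence value drawn above it, as implemented in the new `_synthesize` helper.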