Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature][tf/pt] integrate from_hub for all tasks #892

Merged
merged 34 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
81c313e
backup
felixdittrich92 Jan 11, 2022
50574b5
Merge branch 'mindee:main' into main
felixdittrich92 Jan 11, 2022
5a6ed54
Merge branch 'mindee:main' into main
felixdittrich92 Jan 18, 2022
b9958a7
Merge branch 'mindee:main' into main
felixdittrich92 Jan 20, 2022
14c4651
Merge branch 'mindee:main' into main
felixdittrich92 Feb 16, 2022
779731f
Merge branch 'mindee:main' into main
felixdittrich92 Feb 18, 2022
ce2cdda
Merge branch 'mindee:main' into main
felixdittrich92 Feb 22, 2022
d13dc43
Merge branch 'mindee:main' into main
felixdittrich92 Feb 23, 2022
9a07d73
Merge branch 'mindee:main' into main
felixdittrich92 Feb 24, 2022
a002a70
Merge branch 'mindee:main' into main
felixdittrich92 Feb 24, 2022
6ad096e
Merge branch 'mindee:main' into main
felixdittrich92 Feb 25, 2022
1e77fd4
Merge branch 'mindee:main' into main
felixdittrich92 Mar 8, 2022
2be762c
Merge branch 'mindee:main' into main
felixdittrich92 Mar 10, 2022
e2f2055
Merge branch 'mindee:main' into main
felixdittrich92 Mar 11, 2022
bdc4e67
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
b525021
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
417a27b
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
9b3f5a1
Merge branch 'mindee:main' into main
felixdittrich92 Mar 18, 2022
93074a8
Merge branch 'mindee:main' into main
felixdittrich92 Mar 21, 2022
c64e209
Merge branch 'mindee:main' into main
felixdittrich92 Mar 22, 2022
fdc8381
Merge branch 'mindee:main' into main
felixdittrich92 Mar 25, 2022
bd68b07
Merge branch 'mindee:main' into main
felixdittrich92 Apr 5, 2022
7ac6ee2
Merge branch 'mindee:main' into main
felixdittrich92 Apr 5, 2022
1c79f32
Merge branch 'mindee:main' into main
felixdittrich92 Apr 7, 2022
45e43ac
Merge branch 'mindee:main' into main
felixdittrich92 Apr 13, 2022
5c5c01b
start
felixdittrich92 Apr 13, 2022
90919a1
update tests
felixdittrich92 Apr 14, 2022
b8ce17c
fix missing classification model check
felixdittrich92 Apr 14, 2022
1bf00f0
add torch dummys
felixdittrich92 Apr 14, 2022
be6191e
add tf dummys and fix loading
felixdittrich92 Apr 14, 2022
473252c
fix loading
felixdittrich92 Apr 14, 2022
7c91e50
fix correct cfg loading
felixdittrich92 Apr 15, 2022
0ae8eea
fix correct cfg loading
felixdittrich92 Apr 15, 2022
70c3749
format and missing test
felixdittrich92 Apr 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Any
from typing import Any, List

from doctr.file_utils import is_tf_available, is_torch_available
from doctr.file_utils import is_tf_available

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import CropOrientationPredictor

__all__ = ["crop_orientation_predictor"]


if is_tf_available():
ARCHS = ['mobilenet_v3_small_orientation']
elif is_torch_available():
ARCHS = ['mobilenet_v3_small_orientation']
ARCHS: List[str] = [
'magc_resnet31',
'mobilenet_v3_small', 'mobilenet_v3_small_r', 'mobilenet_v3_large', 'mobilenet_v3_large_r',
'resnet18', 'resnet31', 'resnet34', 'resnet50', 'resnet34_wide',
'vgg16_bn_r'
]
ORIENTATION_ARCHS: List[str] = ['mobilenet_v3_small_orientation']


def _crop_orientation_predictor(
Expand All @@ -26,7 +28,7 @@ def _crop_orientation_predictor(
**kwargs: Any
) -> CropOrientationPredictor:

if arch not in ARCHS:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

# Load directly classifier from backbone
Expand Down
5 changes: 3 additions & 2 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@


if is_tf_available():
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'linknet_resnet18_rotation']
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'linknet_resnet34', 'linknet_resnet50']
ROT_ARCHS = ['linknet_resnet18_rotation']
elif is_torch_available():
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'db_resnet50_rotation']
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18',
'linknet_resnet34', 'linknet_resnet50']
ROT_ARCHS = ['db_resnet50_rotation']


Expand Down
80 changes: 70 additions & 10 deletions doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@
from pathlib import Path
from typing import Any

from huggingface_hub import HfApi, HfFolder, Repository
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

from ..detection import zoo as det_zoo
from ..recognition import zoo as reco_zoo

if is_torch_available():
import torch

__all__ = ['login_to_hub', 'push_to_hf_hub', '_save_model_and_config_for_hf_hub']
__all__ = ['login_to_hub', 'push_to_hf_hub', 'from_hub', '_save_model_and_config_for_hf_hub']


AVAILABLE_ARCHS = {
'detection': det_zoo.ARCHS + det_zoo.ROT_ARCHS,
'recognition': reco_zoo.ARCHS,
'classification': models.classification.zoo.ARCHS,
'detection': models.detection.zoo.ARCHS + models.detection.zoo.ROT_ARCHS,
'recognition': models.recognition.zoo.ARCHS,
'obj_detection': ['fasterrcnn_mobilenet_v3_large_fpn'] if is_torch_available() else None
}

Expand Down Expand Up @@ -124,12 +123,11 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:

```python
>>> from doctr.io import DocumentFile
>>> from doctr.models import ocr_predictor
>>> from doctr.models.<task> import from_hub
>>> from doctr.models import ocr_predictor, from_hub

>>> img = DocumentFile.from_images(['<image_path>'])
>>> # Load your model from the hub
>>> model = from_hub('mindee/my-model').eval()
>>> model = from_hub('mindee/my-model')

>>> # Pass it to the predictor
>>> # If your model is a recognition model:
Expand Down Expand Up @@ -170,3 +168,65 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:
readme_path.write_text(readme)

repo.git_push()


def from_hub(repo_id: str, **kwargs: Any) -> Any:
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
        Model loaded with the checkpoint

    Raises:
        ValueError: if the task stored in the repo config is not supported by the current backend
    """

    # Get the config
    with open(hf_hub_download(repo_id, filename='config.json', **kwargs), 'rb') as f:
        cfg = json.load(f)

    # 'arch' and 'task' only drive the dispatch below: they are removed from the config
    # so that what is attached to the model afterwards only holds the model settings
    arch = cfg.pop('arch')
    task = cfg.pop('task')

    if task == 'classification':
        model = models.classification.__dict__[arch](
            pretrained=False,
            classes=cfg['classes'],
            num_classes=cfg['num_classes']
        )
    elif task == 'detection':
        model = models.detection.__dict__[arch](
            pretrained=False
        )
    elif task == 'recognition':
        model = models.recognition.__dict__[arch](
            pretrained=False,
            input_shape=cfg['input_shape'],
            vocab=cfg['vocab']
        )
    elif task == 'obj_detection' and is_torch_available():
        model = models.obj_detection.__dict__[arch](
            pretrained=False,
            image_mean=cfg['mean'],
            image_std=cfg['std'],
            max_size=cfg['input_shape'][-1],
            num_classes=len(cfg['classes']),
        )
    else:
        # Without this branch `model` would be unbound below and the caller would get an
        # UnboundLocalError that hides the actual cause (unknown task, or an object
        # detection checkpoint requested while PyTorch is not available)
        raise ValueError(f"unsupported task '{task}' for the current backend")

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        # NOTE(review): `torch.load` unpickles the downloaded file, only load checkpoints
        # from repositories you trust
        state_dict = torch.load(
            hf_hub_download(repo_id, filename='pytorch_model.bin', **kwargs),
            map_location='cpu',
        )
        model.load_state_dict(state_dict)
    else:  # tf
        repo_path = snapshot_download(repo_id, **kwargs)
        model.load_weights(os.path.join(repo_path, 'tf_model', 'weights'))

    return model
4 changes: 0 additions & 4 deletions doctr/models/obj_detection/factory/__init__.py

This file was deleted.

50 changes: 0 additions & 50 deletions doctr/models/obj_detection/factory/pytorch.py

This file was deleted.

3 changes: 3 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_classification_zoo(arch_name):
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)
predictor.model.eval()

with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch='wrong_model', pretrained=False)
# object check
assert isinstance(predictor, CropOrientationPredictor)
input_tensor = torch.rand((batch_size, 3, 128, 128))
Expand Down
68 changes: 33 additions & 35 deletions tests/pytorch/test_models_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import pytest

from doctr.models import classification, detection, obj_detection, recognition
from doctr.models.factory import _save_model_and_config_for_hf_hub, push_to_hf_hub
from doctr import models
from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub


def test_push_to_hf_hub():
model = classification.resnet18(pretrained=False)
model = models.classification.resnet18(pretrained=False)
with pytest.raises(ValueError):
# run_config and/or arch must be specified
push_to_hf_hub(model, model_name='test', task='classification')
Expand All @@ -22,42 +22,36 @@ def test_push_to_hf_hub():


@pytest.mark.parametrize(
"arch_name, task_name",
"arch_name, task_name, dummy_model_id",
[
["vgg16_bn_r", "classification"],
["resnet18", "classification"],
["resnet31", "classification"],
["resnet34", "classification"],
["resnet34_wide", "classification"],
["resnet50", "classification"],
["magc_resnet31", "classification"],
["mobilenet_v3_small", "classification"],
["mobilenet_v3_large", "classification"],
["db_resnet34", "detection"],
["db_resnet50", "detection"],
["db_mobilenet_v3_large", "detection"],
["db_resnet50_rotation", "detection"],
["linknet_resnet18", "detection"],
["linknet_resnet34", "detection"],
["linknet_resnet50", "detection"],
["crnn_vgg16_bn", "recognition"],
["crnn_mobilenet_v3_small", "recognition"],
["crnn_mobilenet_v3_large", "recognition"],
["sar_resnet31", "recognition"],
["master", "recognition"],
["fasterrcnn_mobilenet_v3_large_fpn", "obj_detection"],
["vgg16_bn_r", "classification", "Felix92/doctr-dummy-torch-vgg16-bn-r"],
["resnet18", "classification", "Felix92/doctr-dummy-torch-resnet18"],
["resnet31", "classification", "Felix92/doctr-dummy-torch-resnet31"],
["resnet34", "classification", "Felix92/doctr-dummy-torch-resnet34"],
["resnet34_wide", "classification", "Felix92/doctr-dummy-torch-resnet34-wide"],
["resnet50", "classification", "Felix92/doctr-dummy-torch-resnet50"],
["magc_resnet31", "classification", "Felix92/doctr-dummy-torch-magc-resnet31"],
["mobilenet_v3_small", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-small"],
["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-large"],
["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"],
["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"],
["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-torch-db-mobilenet-v3-large"],
["db_resnet50_rotation", "detection", "Felix92/doctr-dummy-torch-db-resnet50-rotation"],
["linknet_resnet18", "detection", "Felix92/doctr-dummy-torch-linknet-resnet18"],
["linknet_resnet34", "detection", "Felix92/doctr-dummy-torch-linknet-resnet34"],
["linknet_resnet50", "detection", "Felix92/doctr-dummy-torch-linknet-resnet50"],
["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-torch-crnn-vgg16-bn"],
["crnn_mobilenet_v3_small", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-small"],
["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-large"],
# ["sar_resnet31", "recognition", ""], enable after model is fixed !
# ["master", "recognition", ""], enable after model is fixed !
["fasterrcnn_mobilenet_v3_large_fpn", "obj_detection",
"Felix92/doctr-dummy-torch-fasterrcnn-mobilenet-v3-large-fpn"],
],
)
def test_models_for_hub(arch_name, task_name, tmpdir):
def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir):
with tempfile.TemporaryDirectory() as tmp_dir:
if task_name == "classification":
model = classification.__dict__[arch_name](pretrained=True).eval()
elif task_name == "detection":
model = detection.__dict__[arch_name](pretrained=True).eval()
elif task_name == "recognition":
model = recognition.__dict__[arch_name](pretrained=True).eval()
elif task_name == "obj_detection":
model = obj_detection.__dict__[arch_name](pretrained=True).eval()
model = models.__dict__[task_name].__dict__[arch_name](pretrained=True).eval()

_save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir)

Expand All @@ -69,3 +63,7 @@ def test_models_for_hub(arch_name, task_name, tmpdir):
assert arch_name == tmp_config['arch']
assert task_name == tmp_config['task']
assert all(key in model.cfg.keys() for key in tmp_config.keys())

# test from hub
hub_model = from_hub(repo_id=dummy_model_id)
assert isinstance(hub_model, type(model))
7 changes: 0 additions & 7 deletions tests/pytorch/test_models_obj_detection_pt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest
import torch
from torchvision.models.detection import FasterRCNN

from doctr.models import obj_detection
from doctr.models.obj_detection.factory import from_hub


@pytest.mark.parametrize(
Expand Down Expand Up @@ -34,8 +32,3 @@ def test_detection_models(arch_name, input_shape, pretrained):
target = [{k: v.cuda() for k, v in t.items()} for t in target]
out = model(input_tensor, target)
assert isinstance(out, dict) and all(isinstance(v, torch.Tensor) for v in out.values())


def test_obj_det_from_hub():
model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval()
assert isinstance(model, FasterRCNN)
2 changes: 2 additions & 0 deletions tests/tensorflow/test_models_classification_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_classification_zoo(arch_name):
batch_size = 16
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)
with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch='wrong_model', pretrained=False)
# object check
assert isinstance(predictor, CropOrientationPredictor)
input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1)
Expand Down
Loading