[Feat] Predictor precision PT backend (#1204)
felixdittrich92 authored Jun 20, 2023
1 parent fdd00a3 commit 31f05c8
Showing 7 changed files with 72 additions and 16 deletions.
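
Taken together, these changes make the PyTorch predictors follow both the device and the dtype of the model's parameters (rather than only the device), which is what allows running inference at reduced precision. A minimal usage sketch; the image path and the .cuda().half() calls are illustrative and not part of this diff, though ocr_predictor and DocumentFile are existing doctr entry points:

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # Build the end-to-end predictor, then move it to GPU in half precision;
    # with this commit the preprocessed batches are cast to the same device/dtype.
    predictor = ocr_predictor(pretrained=True).cuda().half()

    doc = DocumentFile.from_images("path/to/page.png")  # illustrative path
    result = predictor(doc)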
8 changes: 6 additions & 2 deletions doctr/models/classification/predictor/pytorch.py
@@ -10,6 +10,7 @@
 from torch import nn
 
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
 
 __all__ = ["CropOrientationPredictor"]

@@ -42,8 +43,11 @@ def forward(
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
         processed_batches = self.pre_processor(crops)
-        _device = next(self.model.parameters()).device
-        predicted_batches = [self.model(batch.to(device=_device)).to(device=_device) for batch in processed_batches]
+        _params = next(self.model.parameters())
+        self.model, processed_batches = set_device_and_dtype(
+            self.model, processed_batches, _params.device, _params.dtype
+        )
+        predicted_batches = [self.model(batch) for batch in processed_batches]
 
         # Postprocess predictions
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
10 changes: 6 additions & 4 deletions doctr/models/detection/predictor/pytorch.py
@@ -10,6 +10,7 @@
 from torch import nn
 
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
 
 __all__ = ["DetectionPredictor"]

@@ -42,8 +43,9 @@ def forward(
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
         processed_batches = self.pre_processor(pages)
-        _device = next(self.model.parameters()).device
-        predicted_batches = [
-            self.model(batch.to(device=_device), return_preds=True, **kwargs)["preds"] for batch in processed_batches
-        ]
+        _params = next(self.model.parameters())
+        self.model, processed_batches = set_device_and_dtype(
+            self.model, processed_batches, _params.device, _params.dtype
+        )
+        predicted_batches = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
         return [pred for batch in predicted_batches for pred in batch]
5 changes: 4 additions & 1 deletion doctr/models/detection/zoo.py
@@ -70,7 +70,10 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
 
 
 def detection_predictor(
-    arch: Any = "db_resnet50", pretrained: bool = False, assume_straight_pages: bool = True, **kwargs: Any
+    arch: Any = "db_resnet50",
+    pretrained: bool = False,
+    assume_straight_pages: bool = True,
+    **kwargs: Any,
 ) -> DetectionPredictor:
     """Text detection architecture.
10 changes: 6 additions & 4 deletions doctr/models/recognition/predictor/pytorch.py
@@ -10,6 +10,7 @@
 from torch import nn
 
 from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
 
 from ._utils import remap_preds, split_crops

@@ -68,10 +69,11 @@ def forward(
         processed_batches = self.pre_processor(crops)
 
         # Forward it
-        _device = next(self.model.parameters()).device
-        raw = [
-            self.model(batch.to(device=_device), return_preds=True, **kwargs)["preds"] for batch in processed_batches
-        ]
+        _params = next(self.model.parameters())
+        self.model, processed_batches = set_device_and_dtype(
+            self.model, processed_batches, _params.device, _params.dtype
+        )
+        raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
 
         # Process outputs
         out = [charseq for batch in raw for charseq in batch]
29 changes: 27 additions & 2 deletions doctr/models/utils/pytorch.py
@@ -4,14 +4,14 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from doctr.utils.data import download_from_url
 
-__all__ = ["load_pretrained_params", "conv_sequence_pt", "export_model_to_onnx"]
+__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx"]
 
 
 def load_pretrained_params(

@@ -90,6 +90,31 @@ def conv_sequence_pt(
     return conv_seq
 
 
+def set_device_and_dtype(
+    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
+) -> Tuple[Any, List[torch.Tensor]]:
+    """Set the device and dtype of a model and its batches
+
+    >>> import torch
+    >>> from torch import nn
+    >>> from doctr.models.utils import set_device_and_dtype
+    >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
+    >>> batches = [torch.rand(8) for _ in range(2)]
+    >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
+
+    Args:
+        model: the model to be set
+        batches: the batches to be set
+        device: the device to be used
+        dtype: the dtype to be used
+
+    Returns:
+        the model and batches set
+    """
+    return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
+
+
 def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str:
     """Export model to ONNX format.
10 changes: 8 additions & 2 deletions doctr/models/zoo.py
@@ -40,7 +40,10 @@ def _predictor(
 
     # Recognition
     reco_predictor = recognition_predictor(
-        reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs
+        reco_arch,
+        pretrained=pretrained,
+        pretrained_backbone=pretrained_backbone,
+        batch_size=reco_bs,
     )
 
     return OCRPredictor(

@@ -142,7 +145,10 @@ def _kie_predictor(
 
     # Recognition
     reco_predictor = recognition_predictor(
-        reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs
+        reco_arch,
+        pretrained=pretrained,
+        pretrained_backbone=pretrained_backbone,
+        batch_size=reco_bs,
     )
 
     return KIEPredictor(
16 changes: 15 additions & 1 deletion tests/pytorch/test_models_utils_pt.py
@@ -1,9 +1,10 @@
 import os
 
 import pytest
+import torch
 from torch import nn
 
-from doctr.models.utils import conv_sequence_pt, load_pretrained_params
+from doctr.models.utils import conv_sequence_pt, load_pretrained_params, set_device_and_dtype
 
 
 def test_load_pretrained_params(tmpdir_factory):

@@ -32,3 +33,16 @@ def test_conv_sequence():
     assert len(conv_sequence_pt(3, 8, True, kernel_size=3)) == 2
     assert len(conv_sequence_pt(3, 8, False, True, kernel_size=3)) == 2
     assert len(conv_sequence_pt(3, 8, True, True, kernel_size=3)) == 3
+
+
+def test_set_device_and_dtype():
+    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
+    batches = [torch.rand(8) for _ in range(2)]
+    model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float32)
+    assert model[0].weight.device == torch.device("cpu")
+    assert model[0].weight.dtype == torch.float32
+    assert batches[0].device == torch.device("cpu")
+    assert batches[0].dtype == torch.float32
+    model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16)
+    assert model[0].weight.dtype == torch.float16
+    assert batches[0].dtype == torch.float16
