Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactoring rotated boxes #731

Merged
merged 36 commits into from
Dec 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
aa4e253
refacto: rboxes
charlesmindee Dec 16, 2021
dc95b33
feat: builder modifications
charlesmindee Dec 17, 2021
3b3418f
fix: rotate_abs_boxes
charlesmindee Dec 17, 2021
d940277
refacto: viz + metrics
charlesmindee Dec 17, 2021
5090582
fix: merging issues
charlesmindee Dec 20, 2021
29d9d53
fix: flake8 + typing
charlesmindee Dec 20, 2021
17c8572
Merge branch 'main' into refacto_polys
charlesmindee Dec 20, 2021
2903f02
fix: debug 1
charlesmindee Dec 20, 2021
35332e6
fix: debug 2 tests
charlesmindee Dec 21, 2021
e0e32c7
fix: debug test 3
charlesmindee Dec 21, 2021
2876eb9
fix: requested changes
charlesmindee Dec 21, 2021
6a70b70
Merge branch 'main' into refacto_polys
charlesmindee Dec 21, 2021
adb95bf
fix: test rotate
charlesmindee Dec 21, 2021
183ce86
fix: debug 4
charlesmindee Dec 22, 2021
a934869
fix: cv2 tests metrics
charlesmindee Dec 22, 2021
a1b76b5
fix: revert changes
charlesmindee Dec 22, 2021
ad593d2
fix: isort
charlesmindee Dec 23, 2021
3db928f
fix: thresh for rotated ckpt
charlesmindee Dec 23, 2021
b388bf5
fix: utils + scripts
charlesmindee Dec 23, 2021
6ae8240
fix: warnings
charlesmindee Dec 23, 2021
cef1931
fix: tests
charlesmindee Dec 23, 2021
8772c63
Merge branch 'main' into refacto_polys
charlesmindee Dec 23, 2021
2e29f9f
fix: tf zoo tests
charlesmindee Dec 23, 2021
badbb6a
Merge branch 'main' into refacto_polys
charlesmindee Dec 23, 2021
213ac99
fix: requested changes 1
charlesmindee Dec 23, 2021
5f38ab4
fix: tests
charlesmindee Dec 24, 2021
6877601
fix: tests
charlesmindee Dec 24, 2021
3552a97
fix: scripts
charlesmindee Dec 24, 2021
bbf689d
fix: dataset stack
charlesmindee Dec 24, 2021
445f9d4
fix: typing
charlesmindee Dec 24, 2021
6bd4c59
fix: test zoo tf
charlesmindee Dec 24, 2021
17f1758
fix: builder
charlesmindee Dec 24, 2021
a6f01c8
fix: merging conflicts
charlesmindee Dec 24, 2021
4f99940
fix: tests merging conflicts
charlesmindee Dec 24, 2021
219661f
fix: minor changes
charlesmindee Dec 26, 2021
715040a
fix: crop extraction fn
charlesmindee Dec 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: debug 4
  • Loading branch information
charlesmindee committed Dec 22, 2021
commit 183ce86c9c0517210eff0453643cc0db3b9a7fe4
7 changes: 4 additions & 3 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,13 @@ def extract_rcrops(
for box in _boxes:
src_pts = box[1:, :].astype(np.float32)
# Preserve size
_, (w, h), _ = cv2.minAreaRect(box)
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1]], dtype=dtype)
d1 = np.linalg.norm(src_pts[0, :] - src_pts[1, :], axis=-1)
d2 = np.linalg.norm(src_pts[1, :] - src_pts[2, :], axis=-1)
dst_pts = np.array([[0, 0], [d1 - 1, 0], [d1 - 1, d2 - 1]], dtype=dtype)
# The transformation matrix
M = cv2.getAffineTransform(src_pts, dst_pts)
# Warp the rotated rectangle
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(w), int(h)))
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(d1), int(d2)))
crops.append(crop)

return crops
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def bitmap_to_boxes(
if self.assume_straight_pages:
if _box is None or _box[2] < min_size_box or _box[3] < min_size_box:
continue
elif abs(_box[0, 0] - _box[2, 0]) < min_size_box or abs(_box[0, 1] - _box[2, 1]) < min_size_box:
elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box:
continue

if self.assume_straight_pages:
Expand Down
6 changes: 4 additions & 2 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ def _dbnet(arch: str, pretrained: bool, pretrained_backbone: bool = True, **kwar
model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]['url'])

state_dict = torch.load("/home/laptopmindee/Téléchargements/db_resnet50_20211130-090054.pt", map_location='cpu')
# Load weights
model.load_state_dict(state_dict)

return model


Expand Down
2 changes: 2 additions & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.geometry import rotate_boxes, rotate_image
from ..classification import crop_orientation_predictor

from .base import _OCRPredictor

Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes)
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, but since it's a predictor, let's ensure we are in eval mode:
could you add .eval() at the end of the line please?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still missing the .eval(), let's make sure we have all models in eval mode in a predictor


@torch.no_grad()
def forward(
Expand Down
2 changes: 2 additions & 0 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.geometry import rotate_boxes, rotate_image
from doctr.utils.repr import NestedObject
from ..classification import crop_orientation_predictor

from .base import _OCRPredictor

Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes)
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)

def __call__(
self,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _predictor(
det_arch,
pretrained=pretrained,
batch_size=det_bs,
assume_straight_pages=assume_straight_pages
assume_straight_pages=assume_straight_pages,
)

# Recognition
Expand Down
5 changes: 2 additions & 3 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ def create_obj_patch(
if isinstance(geometry, tuple):
if len(geometry) == 2:
return rect_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type]
elif isinstance(geometry, np.ndarray):
return polygon_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type]

elif len(geometry) == 4:
return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) # type: ignore[arg-type]
raise ValueError("invalid geometry format")


Expand Down