Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactoring rotated boxes #731

Merged
merged 36 commits into from
Dec 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
aa4e253
refacto: rboxes
charlesmindee Dec 16, 2021
dc95b33
feat: builder modifications
charlesmindee Dec 17, 2021
3b3418f
fix: rotate_abs_boxes
charlesmindee Dec 17, 2021
d940277
refacto: viz + metrics
charlesmindee Dec 17, 2021
5090582
fix: merging issues
charlesmindee Dec 20, 2021
29d9d53
fix: flake8 + typing
charlesmindee Dec 20, 2021
17c8572
Merge branch 'main' into refacto_polys
charlesmindee Dec 20, 2021
2903f02
fix: debug 1
charlesmindee Dec 20, 2021
35332e6
fix: debug 2 tests
charlesmindee Dec 21, 2021
e0e32c7
fix: debug test 3
charlesmindee Dec 21, 2021
2876eb9
fix: requested changes
charlesmindee Dec 21, 2021
6a70b70
Merge branch 'main' into refacto_polys
charlesmindee Dec 21, 2021
adb95bf
fix: test rotate
charlesmindee Dec 21, 2021
183ce86
fix: debug 4
charlesmindee Dec 22, 2021
a934869
fix: cv2 tests metrics
charlesmindee Dec 22, 2021
a1b76b5
fix: revert changes
charlesmindee Dec 22, 2021
ad593d2
fix: isort
charlesmindee Dec 23, 2021
3db928f
fix: thresh for rotated ckpt
charlesmindee Dec 23, 2021
b388bf5
fix: utils + scripts
charlesmindee Dec 23, 2021
6ae8240
fix: warnings
charlesmindee Dec 23, 2021
cef1931
fix: tests
charlesmindee Dec 23, 2021
8772c63
Merge branch 'main' into refacto_polys
charlesmindee Dec 23, 2021
2e29f9f
fix: tf zoo tests
charlesmindee Dec 23, 2021
badbb6a
Merge branch 'main' into refacto_polys
charlesmindee Dec 23, 2021
213ac99
fix: requested changes 1
charlesmindee Dec 23, 2021
5f38ab4
fix: tests
charlesmindee Dec 24, 2021
6877601
fix: tests
charlesmindee Dec 24, 2021
3552a97
fix: scripts
charlesmindee Dec 24, 2021
bbf689d
fix: dataset stack
charlesmindee Dec 24, 2021
445f9d4
fix: typing
charlesmindee Dec 24, 2021
6bd4c59
fix: test zoo tf
charlesmindee Dec 24, 2021
17f1758
fix: builder
charlesmindee Dec 24, 2021
a6f01c8
fix: merging conflicts
charlesmindee Dec 24, 2021
4f99940
fix: tests merging conflicts
charlesmindee Dec 24, 2021
219661f
fix: minor changes
charlesmindee Dec 26, 2021
715040a
fix: crop extraction fn
charlesmindee Dec 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: builder modifications
  • Loading branch information
charlesmindee committed Dec 17, 2021
commit dc95b3314de44a0072410189dcec556373938e3b
17 changes: 9 additions & 8 deletions doctr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy.cluster.hierarchy import fclusterdata

from doctr.io.elements import Block, Document, Line, Page, Word
from doctr.utils.geometry import rbbox_to_polygon, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
from doctr.utils.repr import NestedObject

__all__ = ['DocumentBuilder']
Expand Down Expand Up @@ -57,7 +57,8 @@ def _sort_boxes(boxes: np.ndarray) -> np.ndarray:
orig_shape=(1024, 1024),
min_angle=5.,
)
return (boxes[:, 0] + 2 * boxes[:, 1] / np.median(boxes[:, 3])).argsort()
# Points are in this order: top left, top right, bot right, bot left
return (boxes[:, 0, 0] + 2 * boxes[:, 1, 2] / np.median(boxes[:, 1, 2] - boxes[:, 1, 1])).argsort()
return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort()

def _resolve_sub_lines(self, boxes: np.ndarray, words: List[int]) -> List[List[int]]:
Expand All @@ -84,7 +85,7 @@ def _resolve_sub_lines(self, boxes: np.ndarray, words: List[int]) -> List[List[i
prev_box = boxes[sub_line[-1]]
# Compute distance between boxes
if len(boxes.shape) == 3:
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
dist = boxes[i, 0] - prev_box[2] / 2 - (prev_box[0] + prev_box[2] / 2)
dist = boxes[i, 0, 0] - prev_box[0, 1]
else:
dist = boxes[i, 0] - prev_box[2]
# If distance between boxes is lower than paragraph break, same sub-line
Expand All @@ -110,7 +111,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
nested list of box indices
"""
# Compute median for boxes heights
y_med = np.median(boxes[:, 3] if len(boxes.shape) == 3 else boxes[:, 3] - boxes[:, 1])
y_med = np.median(boxes[:, 1, 2] - boxes[:, 1, 1] if len(boxes.shape) == 3 else boxes[:, 3] - boxes[:, 1])

# Sort boxes
idxs = self._sort_boxes(boxes)
Expand All @@ -119,7 +120,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
words = [idxs[0]] # Assign the top-left word to the first line
# Define a mean y-center for the line
if len(boxes.shape) == 3:
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
y_center_sum = boxes[idxs[0]][1]
y_center_sum = boxes[idxs[0]][[[1, 2], [1, 1]]].mean()
else:
y_center_sum = boxes[idxs[0]][[1, 3]].mean()

Expand All @@ -128,7 +129,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:

# Compute y_dist
if len(boxes.shape) == 3:
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
y_dist = abs(boxes[idx][1] - y_center_sum / len(words))
y_dist = abs(boxes[idx][[[1, 2], [1, 1]]].mean() - y_center_sum / len(words))
else:
y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words))
# If y-center of the box is close enough to mean y-center of the line, same line
Expand All @@ -142,7 +143,7 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
y_center_sum = 0

words.append(idx)
y_center_sum += boxes[idx][1 if len(boxes.shape) == 3 else [1, 3]].mean()
y_center_sum += boxes[idx][[[1, 2], [1, 1]] if len(boxes.shape) == 3 else [1, 3]].mean()

# Use the remaining words to form the last(s) line(s)
if len(words) > 0:
Expand Down Expand Up @@ -241,7 +242,7 @@ def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]])
[
Word(
*word_preds[idx],
(boxes[idx, 0], boxes[idx, 1], boxes[idx, 2], boxes[idx, 3], boxes[idx, 4])
tuple(boxes[idx].tolist())
) if len(boxes.shape) == 3 else
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
Word(
*word_preds[idx],
Expand Down
10 changes: 7 additions & 3 deletions doctr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P:
return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1]


# TODO: deprecation warning (a rbbox is now a polygon)
def rbbox_to_polygon(rbbox: RotatedBbox) -> Polygon4P:
(x, y, w, h, alpha) = rbbox
return cv2.boxPoints(((float(x), float(y)), (float(w), float(h)), -float(alpha)))


# TODO: deprecation warning (a rbbox is now a polygon)
def fit_rbbox(pts: np.ndarray) -> RotatedBbox:
((x, y), (w, h), alpha) = cv2.minAreaRect(pts)
return x, y, h, w, 90 - alpha
Expand All @@ -35,6 +37,7 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox:
return (min(x), min(y)), (max(x), max(y))


# TODO: deprecation warning (a rbbox is now a polygon)
def polygon_to_rbbox(polygon: Polygon4P) -> RotatedBbox:
cnt = np.array(polygon).reshape((-1, 1, 2)).astype(np.float32)
return fit_rbbox(cnt)
Expand All @@ -58,9 +61,10 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio
return (min(x), min(y)), (max(x), max(y))


def resolve_enclosing_rbbox(rbboxes: List[RotatedBbox]) -> RotatedBbox:
pts = np.asarray([pt for rbbox in rbboxes for pt in rbbox_to_polygon(rbbox)], np.float32)
return fit_rbbox(pts)
def resolve_enclosing_rbbox(rbboxes: List[RotatedBbox]) -> Polygon4P:
cloud = np.concatenate(rbboxes, axis=0)
rect = cv2.minAreaRect(cloud)
return cv2.boxPoints(rect)


def rotate_abs_points(points: np.ndarray, angle: float = 0.) -> np.ndarray:
Expand Down