Skip to content

Commit

Permalink
refactor: Cleaned rotation transforms (#536)
Browse files Browse the repository at this point in the history
* fix: Fixed box rotation

* refactor: Reflected changes on rotations in TF & Pytorch

* fix: Fixed rotation

* test: Updated unittests

* docs: Added illustration in doc of RandomRotate

* refactor: Updated RandomRotate args
  • Loading branch information
fg-mindee authored Oct 22, 2021
1 parent 323d484 commit 223abaf
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 35 deletions.
7 changes: 1 addition & 6 deletions doctr/transforms/functional/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,7 @@ def rotate(
_boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[1]

# Rotate the boxes: xmin, ymin, xmax, ymax --> x, y, w, h, alpha
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:]) # type: ignore[arg-type]

# Apply the expansion
if expand:
r_boxes[:, 0] += int((rotated_img.shape[2] - img.shape[2]) / 2)
r_boxes[:, 1] += int((rotated_img.shape[1] - img.shape[1]) / 2)
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[1:], expand) # type: ignore[arg-type]

# Convert them to relative
if boxes.dtype != int:
Expand Down
7 changes: 1 addition & 6 deletions doctr/transforms/functional/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,7 @@ def rotate(
_boxes[:, [1, 3]] = _boxes[:, [1, 3]] * img.shape[0]

# Rotate the boxes: xmin, ymin, xmax, ymax --> x, y, w, h, alpha
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1])

# Apply the expansion
if expand:
r_boxes[:, 0] += int((rotated_img.shape[1] - img.shape[1]) / 2)
r_boxes[:, 1] += int((rotated_img.shape[0] - img.shape[0]) / 2)
r_boxes = rotate_abs_boxes(_boxes, angle, img.shape[:-1], expand)

# Convert them to relative
if boxes.dtype != int:
Expand Down
9 changes: 6 additions & 3 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __call__(self, img: Any) -> Any:
class RandomRotate(NestedObject):
"""Randomly rotate a tensor image and its boxes
.. image:: https://github.com/mindee/doctr/releases/download/v0.4.0/rotation_illustration.png
:align: center
Args:
max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
[-max_angle, max_angle]
Expand All @@ -104,10 +107,10 @@ def __init__(self, max_angle: float = 5., expand: bool = False) -> None:
def extra_repr(self) -> str:
return f"max_angle={self.max_angle}, expand={self.expand}"

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
angle = random.uniform(-self.max_angle, self.max_angle)
r_img, r_boxes = F.rotate(img, target["boxes"], angle, self.expand)
return r_img, dict(boxes=r_boxes)
r_img, r_boxes = F.rotate(img, target, angle, self.expand)
return r_img, r_boxes


class RandomCrop(NestedObject):
Expand Down
42 changes: 28 additions & 14 deletions doctr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,16 @@ def resolve_enclosing_rbbox(rbboxes: List[RotatedBbox]) -> RotatedBbox:
return fit_rbbox(pts)


def rotate_abs_points(points: np.ndarray, center: np.ndarray, angle: float = 0.) -> np.ndarray:

# Y-axis is inverted by convention
rel_points = np.stack((points[:, 0] - center[0], center[1] - points[:, 1]), axis=1)
def rotate_abs_points(points: np.ndarray, angle: float = 0.) -> np.ndarray:
"""Rotate points counter-clockwise"""

angle_rad = angle * np.pi / 180. # compute radian angle for np functions
rotation_mat = np.array([
[np.cos(angle_rad), -np.sin(angle_rad)],
[np.sin(angle_rad), np.cos(angle_rad)]
], dtype=rel_points.dtype)

rotated_rel_points = np.matmul(rel_points, rotation_mat.T)
rotated_rel_points[:, 0] += center[0]
rotated_rel_points[:, 1] = center[1] - rotated_rel_points[:, 1]
], dtype=points.dtype)

return rotated_rel_points
return np.matmul(points, rotation_mat.T)


def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]:
Expand All @@ -95,33 +89,53 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in
[-img_shape[1] / 2, img_shape[0] / 2],
])

rotated_points = rotate_abs_points(points, np.zeros(2), angle)
rotated_points = rotate_abs_points(points, angle)

wh_shape = 2 * np.abs(rotated_points).max(axis=0)

return wh_shape[1], wh_shape[0]


def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int]) -> np.ndarray:
def rotate_abs_boxes(boxes: np.ndarray, angle: float, img_shape: Tuple[int, int], expand: bool = True) -> np.ndarray:
"""Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax) by an angle around the image center.
Args:
boxes: (N, 4) array of absolute coordinate boxes
angle: angle between -90 and +90 degrees
img_shape: the height and width of the image
expand: whether the image should be padded to avoid information loss
Returns:
A batch of rotated boxes (N, 5): (x, y, w, h, alpha) or a batch of straight bounding boxes
"""

# Get box centers
box_centers = np.stack((boxes[:, 0] + boxes[:, 2], boxes[:, 1] + boxes[:, 3]), axis=1) / 2
img_corners = np.array([[0, 0], [0, img_shape[0]], [*img_shape[::-1]], [img_shape[1], 0]], dtype=boxes.dtype)

stacked_points = np.concatenate((img_corners, box_centers), axis=0)
# Y-axis is inverted by conversion
stacked_rel_points = np.stack(
(stacked_points[:, 0] - img_shape[1] / 2, img_shape[0] / 2 - stacked_points[:, 1]),
axis=1
)

# Rotate them around image center
box_centers = rotate_abs_points(box_centers, np.array(img_shape[::-1]) / 2, angle)
rot_points = rotate_abs_points(stacked_rel_points, angle)
rot_corners, rot_centers = rot_points[:4], rot_points[4:]

# Expand the image to fit all the original info
if expand:
new_corners = np.abs(rot_corners).max(axis=0)
rot_centers[:, 0] += new_corners[0]
rot_centers[:, 1] = new_corners[1] - rot_centers[:, 1]
else:
rot_centers[:, 0] += img_shape[1] / 2
rot_centers[:, 1] = img_shape[0] / 2 - rot_centers[:, 1]

# Rotated bbox conversion
rotated_boxes = np.concatenate((
box_centers,
rot_centers,
np.stack((boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]), axis=1),
np.full((boxes.shape[0], 1), angle, dtype=box_centers.dtype)
), axis=1)
Expand Down
10 changes: 7 additions & 3 deletions test/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,18 @@ def test_random_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, boxes)
assert r_img.shape == input_t.shape
assert abs(r_boxes["boxes"][-1, -1]) <= 10.
assert abs(r_boxes[-1, -1]) <= 10.

rotator = RandomRotate(max_angle=10., expand=True)
r_img, r_boxes = rotator(input_t, boxes)
assert r_img.shape != input_t.shape

# FP16 (only on GPU)
if torch.cuda.is_available():
input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda()
r_img, _ = rotator(input_t, dict(boxes=boxes))
r_img, _ = rotator(input_t, boxes)
assert r_img.dtype == torch.float16


Expand Down
10 changes: 7 additions & 3 deletions test/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,17 @@ def test_random_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, boxes)
assert r_img.shape == input_t.shape
assert abs(r_boxes["boxes"][-1, -1]) <= 10.
assert abs(r_boxes[-1, -1]) <= 10.

rotator = T.RandomRotate(max_angle=10., expand=True)
r_img, r_boxes = rotator(input_t, boxes)
assert r_img.shape != input_t.shape

# FP16
input_t = tf.ones((50, 50, 3), dtype=tf.float16)
r_img, _ = rotator(input_t, dict(boxes=boxes))
r_img, _ = rotator(input_t, boxes)
assert r_img.dtype == tf.float16


Expand Down

0 comments on commit 223abaf

Please sign in to comment.