Skip to content

Commit

Permalink
[transforms] Add RandomResize (like ZoomOut) (#1574)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Apr 30, 2024
1 parent 46d5974 commit 2940d9d
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/modules/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Here are all transformations that are available through docTR:
.. autoclass:: GaussianNoise
.. autoclass:: RandomHorizontalFlip
.. autoclass:: RandomShadow
.. autoclass:: RandomResize


Composing transformations
Expand Down
37 changes: 36 additions & 1 deletion doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..functional.pytorch import random_shadow

__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow"]
__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]


class Resize(T.Resize):
Expand Down Expand Up @@ -213,3 +213,38 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:

def extra_repr(self) -> str:
return f"opacity_range={self.opacity_range}"


class RandomResize(torch.nn.Module):
"""Randomly resize the input image and align corresponding targets
>>> import torch
>>> from doctr.transforms import RandomResize
>>> transfo = RandomResize((0.3, 0.9), p=0.5)
>>> out = transfo(torch.rand((3, 64, 64)))
Args:
----
scale_range: range of the resizing factor for width and height (independently)
p: probability to apply the transformation
"""

def __init__(self, scale_range: Tuple[float, float] = (0.3, 0.9), p: float = 0.5) -> None:
super().__init__()
self.scale_range = scale_range
self.p = p
self._resize = Resize

def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
if torch.rand(1) < self.p:
scale_h = np.random.uniform(*self.scale_range)
scale_w = np.random.uniform(*self.scale_range)
new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w))

_img, _target = self._resize(new_size, preserve_aspect_ratio=True, symmetric_pad=True)(img, target)

return _img, _target
return img, target

def extra_repr(self) -> str:
return f"scale_range={self.scale_range}, p={self.p}"
36 changes: 36 additions & 0 deletions doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"GaussianNoise",
"RandomHorizontalFlip",
"RandomShadow",
"RandomResize",
]


Expand Down Expand Up @@ -515,3 +516,38 @@ def __call__(self, x: tf.Tensor) -> tf.Tensor:

def extra_repr(self) -> str:
return f"opacity_range={self.opacity_range}"


class RandomResize(NestedObject):
"""Randomly resize the input image and align corresponding targets
>>> import tensorflow as tf
>>> from doctr.transforms import RandomResize
>>> transfo = RandomResize((0.3, 0.9), p=0.5)
>>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))
Args:
----
scale_range: range of the resizing factor for width and height (independently)
p: probability to apply the transformation
"""

def __init__(self, scale_range: Tuple[float, float] = (0.3, 0.9), p: float = 0.5) -> None:
super().__init__()
self.scale_range = scale_range
self.p = p
self._resize = Resize

def __call__(self, img: tf.Tensor, target: np.ndarray) -> Tuple[tf.Tensor, np.ndarray]:
if np.random.rand(1) <= self.p:
scale_h = random.uniform(*self.scale_range)
scale_w = random.uniform(*self.scale_range)
new_size = (int(img.shape[-3] * scale_h), int(img.shape[-2] * scale_w))

_img, _target = self._resize(new_size, preserve_aspect_ratio=True, symmetric_pad=True)(img, target)

return _img, _target
return img, target

def extra_repr(self) -> str:
return f"scale_range={self.scale_range}, p={self.p}"
24 changes: 24 additions & 0 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
GaussianNoise,
RandomCrop,
RandomHorizontalFlip,
RandomResize,
RandomRotate,
RandomShadow,
Resize,
Expand Down Expand Up @@ -326,3 +327,26 @@ def test_random_shadow(input_dtype, input_shape):
assert torch.all(transformed <= 255)
else:
assert torch.all(transformed <= 1.0)


@pytest.mark.parametrize(
"p,target",
[
[1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)],
[0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)],
[1, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)],
[0, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)],
],
)
def test_random_resize(p, target):
transfo = RandomResize(scale_range=(0.3, 1.3), p=p)
assert repr(transfo) == f"RandomResize(scale_range=(0.3, 1.3), p={p})"

img = torch.rand((3, 64, 64))
# Apply the transformation
out_img, out_target = transfo(img, target)
assert isinstance(out_img, torch.Tensor)
assert isinstance(out_target, np.ndarray)
# Resize is already well tested
assert torch.all(out_img == img) if p == 0 else out_img.shape != img.shape
assert out_target.shape == target.shape
23 changes: 23 additions & 0 deletions tests/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,26 @@ def test_random_shadow(input_dtype, input_shape):
assert tf.reduce_all(transformed <= 255)
else:
assert tf.reduce_all(transformed <= 1.0)


@pytest.mark.parametrize(
"p,target",
[
[1, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)],
[0, np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)],
[1, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)],
[0, np.array([[[0.1, 0.8], [0.3, 0.1], [0.3, 0.4], [0.8, 0.4]]], dtype=np.float32)],
],
)
def test_random_resize(p, target):
transfo = T.RandomResize(scale_range=(0.3, 1.3), p=p)
assert repr(transfo) == f"RandomResize(scale_range=(0.3, 1.3), p={p})"

img = tf.random.uniform((64, 64, 3))
# Apply the transformation
out_img, out_target = transfo(img, target)
assert isinstance(out_img, tf.Tensor)
assert isinstance(out_target, np.ndarray)
# Resize is already well-tested
assert tf.reduce_all(tf.equal(out_img, img)) if p == 0 else out_img.shape != img.shape
assert out_target.shape == target.shape

0 comments on commit 2940d9d

Please sign in to comment.