Skip to content

Commit

Permalink
feat: Adds GaussianBlur, random font for CharGenerator and improves t…
Browse files Browse the repository at this point in the history
…raining scripts (#758)

* feat: Added GaussianBlur for TF

* test: Updated unittests

* docs: Added GaussianBlur to the transformations

* feat: Added possibility to pick random font for CharGenerator

* docs: Updated CharGenerator docstring

* feat: Improves character classification training script

* style: Fixed typing
  • Loading branch information
fg-mindee authored Dec 26, 2021
1 parent e8583f3 commit 56a5830
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 55 deletions.
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Here are all transformations that are available through docTR:
.. autoclass:: RandomJpegQuality
.. autoclass:: RandomRotate
.. autoclass:: RandomCrop
.. autoclass:: GaussianBlur


Composing transformations
Expand Down
67 changes: 50 additions & 17 deletions doctr/datasets/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Any, Callable, List, Optional, Tuple
import random
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image, ImageDraw

Expand All @@ -13,27 +14,48 @@
from ..datasets import AbstractDataset


def synthesize_char_img(char: str, size: int = 32, font_family: Optional[str] = None) -> Image:
def synthesize_text_img(
text: str,
img_size: Optional[Tuple[int, int]] = None,
font_size: Optional[int] = None,
font_family: Optional[str] = None,
background_color: Optional[Tuple[int, int, int]] = None,
text_color: Optional[Tuple[int, int, int]] = None,
text_pos: Optional[Tuple[int, int]] = None,
) -> Image:
"""Generate a synthetic character image with black background and white text
Args:
char: the character to render as an image
size: the size of the rendered image
text: the character to render as an image
img_size: the size of the rendered image
font_size: the size of the font
font_family: the font family (has to be installed on your system)
background_color: background color of the final image
text_color: text color on the final image
text_pos: offset of the text
Returns:
PIL image of the character
"""

if len(char) != 1:
raise AssertionError('expected a single character input')
background_color = (0, 0, 0) if background_color is None else background_color
text_color = (255, 255, 255) if text_color is None else text_color
default_h = 32
if font_size is None:
font_size = int(0.9 * default_h) if img_size is None else int(0.9 * img_size[0])

img = Image.new('RGB', (size, size), color=(0, 0, 0))
font = get_font(font_family, font_size)
text_size = font.getsize(text)
if img_size is None:
img_size = (default_h, text_size[0] if len(text) > 1 else default_h)

img = Image.new('RGB', img_size[::-1], color=background_color)
d = ImageDraw.Draw(img)

# Draw the character
font = get_font(font_family, size)
d.text((4, 0), char, font=font, fill=(255, 255, 255))
if text_pos is None:
text_pos = (0, 0) if text_size[0] >= img_size[1] else (int(round(img_size[0] * 3 / 16)), 0)
d.text(text_pos, text, font=font, fill=text_color)

return img

Expand All @@ -45,30 +67,41 @@ def __init__(
vocab: str,
num_samples: int,
cache_samples: bool = False,
font_family: Optional[Union[str, List[str]]] = None,
img_transforms: Optional[Callable[[Any], Any]] = None,
sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
font_family: Optional[str] = None,
) -> None:
self.img_transforms = img_transforms
self.sample_transforms = sample_transforms
self.vocab = vocab
self._num_samples = num_samples
self.font_family = font_family
self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item]
# Validate fonts
if isinstance(font_family, list):
for font in self.font_family:
try:
_ = get_font(font, 10)
except OSError:
raise ValueError(f"unable to locate font: {font}")
self.img_transforms = img_transforms
self.sample_transforms = sample_transforms

self._data: List[Image.Image] = []
if cache_samples:
self._data = [synthesize_char_img(char, font_family=self.font_family) for char in self.vocab]
self._data = [
(synthesize_text_img(char, font_family=font), idx)
for idx, char in enumerate(self.vocab) for font in self.font_family
]

def __len__(self) -> int:
return self._num_samples

def _read_sample(self, index: int) -> Tuple[Any, int]:
target = index % len(self.vocab)
# Samples are already cached
if len(self._data) > 0:
pil_img = self._data[target].copy()
idx = index % len(self._data)
pil_img, target = self._data[idx]
else:
pil_img = synthesize_char_img(self.vocab[target], font_family=self.font_family)
target = index % len(self.vocab)
pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
img = tensor_from_pil(pil_img)

return img, target
4 changes: 3 additions & 1 deletion doctr/datasets/classification/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class CharacterGenerator(_CharacterGenerator):
vocab: vocabulary to take the character from
num_samples: number of samples that will be generated iterating over the dataset
cache_samples: whether generated images should be cached firsthand
sample_transforms: composable transformations that will be applied to each image
font_family: font to use to generate the text images
img_transforms: composable transformations that will be applied to each image
sample_transforms: composable transformations that will be applied to both the image and the target
"""

def __init__(self, *args, **kwargs) -> None:
Expand Down
4 changes: 3 additions & 1 deletion doctr/datasets/classification/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class CharacterGenerator(_CharacterGenerator):
vocab: vocabulary to take the character from
num_samples: number of samples that will be generated iterating over the dataset
cache_samples: whether generated images should be cached firsthand
sample_transforms: composable transformations that will be applied to each image
font_family: font to use to generate the text images
img_transforms: composable transformations that will be applied to each image
sample_transforms: composable transformations that will be applied to both the image and the target
"""

def __init__(self, *args, **kwargs) -> None:
Expand Down
39 changes: 36 additions & 3 deletions doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import random
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, Iterable, List, Tuple, Union

import tensorflow as tf
import tensorflow_addons as tfa

from doctr.utils.repr import NestedObject

__all__ = ['Compose', 'Resize', 'Normalize', 'LambdaTransformation', 'ToGray', 'RandomBrightness',
'RandomContrast', 'RandomSaturation', 'RandomHue', 'RandomGamma', 'RandomJpegQuality']
'RandomContrast', 'RandomSaturation', 'RandomHue', 'RandomGamma', 'RandomJpegQuality', 'GaussianBlur']


class Compose(NestedObject):
Expand Down Expand Up @@ -141,8 +142,12 @@ class ToGray(NestedObject):
>>> transfo = ToGray()
>>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1))
"""
def __init__(self, num_output_channels: int = 1):
self.num_output_channels = num_output_channels

def __call__(self, img: tf.Tensor) -> tf.Tensor:
return tf.image.rgb_to_grayscale(img)
img = tf.image.rgb_to_grayscale(img)
return img if self.num_output_channels == 1 else tf.repeat(img, self.num_output_channels, axis=-1)


class RandomBrightness(NestedObject):
Expand Down Expand Up @@ -298,3 +303,31 @@ def __call__(self, img: tf.Tensor) -> tf.Tensor:
return tf.image.random_jpeg_quality(
img, min_jpeg_quality=self.min_quality, max_jpeg_quality=self.max_quality
)


class GaussianBlur(NestedObject):
    """Apply a gaussian blur to an image, using a standard deviation that is
    uniformly sampled from a given range at each call

    Example::
        >>> from doctr.transforms import GaussianBlur
        >>> import tensorflow as tf
        >>> transfo = GaussianBlur(3, (.1, 5))
        >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1))

    Args:
        kernel_shape: size of the blurring kernel
        std: min and max value of the standard deviation
    """
    def __init__(self, kernel_shape: Union[int, Iterable[int]], std: Tuple[float, float]) -> None:
        self.kernel_shape = kernel_shape
        self.std = std

    def extra_repr(self) -> str:
        return f"kernel_shape={self.kernel_shape}, std={self.std}"

    def __call__(self, img: tf.Tensor) -> tf.Tensor:
        """Blur the input image

        Args:
            img: image tensor to blur

        Returns:
            the blurred image, as returned by `tfa.image.gaussian_filter2d`
        """
        # NOTE: this method is deliberately NOT wrapped in `tf.function`. Python-level
        # randomness only runs while the function is being traced, so the sampled sigma
        # would be frozen into the graph and reused for every subsequent call.
        sigma = random.uniform(self.std[0], self.std[1])
        return tfa.image.gaussian_filter2d(
            img, filter_shape=self.kernel_shape, sigma=sigma,
        )
33 changes: 24 additions & 9 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from torch.nn.functional import cross_entropy
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import models
from torchvision.transforms import ColorJitter, Compose, Normalize, RandomPerspective
from torchvision.transforms import (ColorJitter, Compose, GaussianBlur, Grayscale, InterpolationMode, Normalize,
RandomRotation)

from doctr import transforms as T
from doctr.datasets import VOCABS, CharacterGenerator
from doctr.models import classification
from utils import plot_recorder, plot_samples


Expand Down Expand Up @@ -175,14 +176,20 @@ def main(args):

vocab = VOCABS[args.vocab]

fonts = args.font.split(",")

# Load val data generator
st = time.time()
val_set = CharacterGenerator(
vocab=vocab,
num_samples=args.val_samples * len(vocab),
cache_samples=True,
img_transforms=T.Resize((args.input_size, args.input_size)),
font_family=args.font,
img_transforms=Compose([
T.Resize((args.input_size, args.input_size)),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(), .9),
]),
font_family=fonts,
)
val_loader = DataLoader(
val_set,
Expand All @@ -198,7 +205,7 @@ def main(args):
batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab))
model = classification.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab))

# Resume weights
if isinstance(args.resume, str):
Expand Down Expand Up @@ -237,11 +244,14 @@ def main(args):
img_transforms=Compose([
T.Resize((args.input_size, args.input_size)),
# Augmentations
RandomPerspective(),
T.RandomApply(T.ColorInversion(), .7),
T.RandomApply(T.ColorInversion(), .9),
# GaussianNoise
T.RandomApply(Grayscale(3), .1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
T.RandomApply(GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)), .3),
RandomRotation(15, interpolation=InterpolationMode.BILINEAR),
]),
font_family=args.font,
font_family=fonts,
)

train_loader = DataLoader(
Expand Down Expand Up @@ -340,7 +350,12 @@ def parse_args():
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument(
'--font',
type=str,
default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf",
help='Font family to be used'
)
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument(
'--train-samples',
Expand Down
26 changes: 20 additions & 6 deletions references/classification/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def main(args):

vocab = VOCABS[args.vocab]

fonts = args.font.split(",")

# AMP
if args.amp:
mixed_precision.set_global_policy('mixed_float16')
Expand All @@ -146,8 +148,12 @@ def main(args):
vocab=vocab,
num_samples=args.val_samples * len(vocab),
cache_samples=True,
img_transforms=T.Resize((args.input_size, args.input_size)),
font_family=args.font,
img_transforms=T.Compose([
T.Resize((args.input_size, args.input_size)),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(), .9),
]),
font_family=fonts,
)
val_loader = DataLoader(
val_set,
Expand Down Expand Up @@ -192,13 +198,16 @@ def main(args):
img_transforms=T.Compose([
T.Resize((args.input_size, args.input_size)),
# Augmentations
T.RandomApply(T.ColorInversion(), .7),
T.RandomApply(T.ColorInversion(), .9),
T.RandomApply(T.ToGray(3), .1),
T.RandomJpegQuality(60),
T.RandomSaturation(.3),
T.RandomContrast(.3),
T.RandomBrightness(.3),
# Blur
T.RandomApply(T.GaussianBlur(kernel_shape=(3, 3), std=(0.1, 3)), .3),
]),
font_family=args.font,
font_family=fonts,
)
train_loader = DataLoader(
train_set,
Expand All @@ -220,7 +229,7 @@ def main(args):
scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
args.lr,
decay_steps=args.epochs * len(train_loader),
decay_rate=1 / (25e4), # final lr as a fraction of initial lr
decay_rate=1 / (1e3), # final lr as a fraction of initial lr
staircase=False
)
optimizer = tf.keras.optimizers.Adam(
Expand Down Expand Up @@ -302,7 +311,12 @@ def parse_args():
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument(
'--font',
type=str,
default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf",
help='Font family to be used'
)
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training')
parser.add_argument(
'--train-samples',
Expand Down
27 changes: 9 additions & 18 deletions tests/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,24 +230,6 @@ def test_jpegquality():
assert out.dtype == tf.float16


def test_oneof():
transfos = [
T.RandomGamma(min_gamma=1., max_gamma=2., min_gain=.8, max_gain=1.),
T.RandomContrast(delta=.2)
]
input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.), dtype=tf.float32)
out = T.OneOf(transfos)(input_t)
assert ((tf.reduce_all(out >= 1.6) and tf.reduce_all(out <= 4.)) or tf.reduce_all(out == 2.))


def test_randomapply():

transfo = T.RandomGamma(min_gamma=1., max_gamma=2., min_gain=.8, max_gain=1.)
input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.), dtype=tf.float32)
out = T.RandomApply(transfo, p=1.)(input_t)
assert (tf.reduce_all(out >= 1.6) and tf.reduce_all(out <= 4.))


def test_rotate():
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
Expand Down Expand Up @@ -329,3 +311,12 @@ def test_random_crop():
new_h, new_w = c_img.shape[:2]
assert new_h >= 3
assert new_w >= 3


def test_gaussian_blur():
    """Check that GaussianBlur preserves the image shape and bleeds
    neighbouring intensity into an isolated dark pixel"""
    # All-ones image with a single zeroed pixel at its center
    src = np.ones((31, 31, 3), dtype=np.float32)
    src[15, 15] = 0

    transfo = T.GaussianBlur(3, (.1, 3))
    out = transfo(tf.convert_to_tensor(src)).numpy()

    # Blurring must not alter the spatial dimensions or the channel count
    assert out.shape == src.shape
    # The surrounding ones have to leak into the zeroed pixel, on every channel
    assert np.all(out[15, 15] > 0)

0 comments on commit 56a5830

Please sign in to comment.