diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py
index d6b86fa8e1..056e61dc68 100644
--- a/doctr/datasets/cord.py
+++ b/doctr/datasets/cord.py
@@ -58,6 +58,7 @@ def __init__(
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
             stem = Path(img_path).stem
             _targets = []
             with open(os.path.join(self.root, 'json', f"{stem}.json"), 'rb') as f:
diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py
index b7f15a0899..1a45acf174 100644
--- a/doctr/datasets/detection.py
+++ b/doctr/datasets/detection.py
@@ -49,6 +49,10 @@ def __init__(
 
         self.data: List[Tuple[str, np.ndarray]] = []
         for img_name, label in labels.items():
+            # File existence check
+            if not os.path.exists(os.path.join(self.root, img_name)):
+                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
             polygons = np.asarray(label['polygons'])
             if rotated_bbox:
                 # Switch to rotated rects
diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py
index cbfccd3d28..2aff011f71 100644
--- a/doctr/datasets/doc_artefacts.py
+++ b/doctr/datasets/doc_artefacts.py
@@ -61,6 +61,7 @@ def __init__(
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_name)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")
+
             boxes = np.asarray([obj['geometry'] for obj in label], dtype=np_dtype)
             classes = np.asarray([self.CLASSES.index(obj['label']) for obj in label], dtype=np.long)
             if rotated_bbox:
diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py
index 058d1e7cff..20687ec198 100644
--- a/doctr/datasets/funsd.py
+++ b/doctr/datasets/funsd.py
@@ -57,6 +57,7 @@ def __init__(
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
             stem = Path(img_path).stem
             with open(os.path.join(self.root, subfolder, 'annotations', f"{stem}.json"), 'rb') as f:
                 data = json.load(f)
diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py
index ea2ddf4de5..692bfc962a 100644
--- a/doctr/datasets/recognition.py
+++ b/doctr/datasets/recognition.py
@@ -39,14 +39,12 @@ def __init__(
         self.data: List[Tuple[str, str]] = []
         with open(labels_path) as f:
             labels = json.load(f)
-        for img_path in os.listdir(self.root):
-            # File existence check
-            if not os.path.exists(os.path.join(self.root, img_path)):
-                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_path)}")
-            label = labels.get(img_path)
-            if not isinstance(label, str):
-                raise KeyError("Image is not in referenced in label file")
-            self.data.append((img_path, label))
+
+        for img_name, label in labels.items():
+            if not os.path.exists(os.path.join(self.root, img_name)):
+                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
+            self.data.append((img_name, label))
 
     def merge_dataset(self, ds: AbstractDataset) -> None:
         # Update data with new root for self
diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py
index 54c5cb6f58..8ab9d899f0 100644
--- a/doctr/datasets/sroie.py
+++ b/doctr/datasets/sroie.py
@@ -60,6 +60,7 @@ def __init__(
             # File existence check
             if not os.path.exists(os.path.join(tmp_root, img_path)):
                 raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
             stem = Path(img_path).stem
             _targets = []
             with open(os.path.join(self.root, 'annotations', f"{stem}.txt"), encoding='latin') as f:
diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py
new file mode 100644
index 0000000000..6a3fee30ac
--- /dev/null
+++ b/doctr/models/detection/_utils/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+    from .tensorflow import *
+else:
+    from .pytorch import *  # type: ignore[misc]
diff --git a/doctr/models/detection/_utils/pytorch.py b/doctr/models/detection/_utils/pytorch.py
new file mode 100644
index 0000000000..456f9dfaa7
--- /dev/null
+++ b/doctr/models/detection/_utils/pytorch.py
@@ -0,0 +1,37 @@
+# Copyright (C) 2021, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
+
+from torch import Tensor
+from torch.nn.functional import max_pool2d
+
+__all__ = ['erode', 'dilate']
+
+
+def erode(x: Tensor, kernel_size: int) -> Tensor:
+    """Performs erosion on a given tensor
+
+    Args:
+        x: boolean tensor of shape (N, C, H, W)
+        kernel_size: the size of the kernel to use for erosion
+    Returns:
+        the eroded tensor
+    """
+    _pad = (kernel_size - 1) // 2
+
+    return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad)
+
+
+def dilate(x: Tensor, kernel_size: int) -> Tensor:
+    """Performs dilation on a given tensor
+
+    Args:
+        x: boolean tensor of shape (N, C, H, W)
+        kernel_size: the size of the kernel to use for dilation
+    Returns:
+        the dilated tensor
+    """
+    _pad = (kernel_size - 1) // 2
+
+    return max_pool2d(x, kernel_size, stride=1, padding=_pad)
diff --git a/doctr/models/detection/_utils/tensorflow.py b/doctr/models/detection/_utils/tensorflow.py
new file mode 100644
index 0000000000..0cadfd7eaa
--- /dev/null
+++ b/doctr/models/detection/_utils/tensorflow.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2021, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
+
+import tensorflow as tf
+
+__all__ = ['erode', 'dilate']
+
+
+def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
+    """Performs erosion on a given tensor
+
+    Args:
+        x: boolean tensor of shape (N, H, W, C)
+        kernel_size: the size of the kernel to use for erosion
+    Returns:
+        the eroded tensor
+    """
+
+    return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
+
+
+def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
+    """Performs dilation on a given tensor
+
+    Args:
+        x: boolean tensor of shape (N, H, W, C)
+        kernel_size: the size of the kernel to use for dilation
+    Returns:
+        the dilated tensor
+    """
+
+    return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py
index 4e6b6959ca..0228306103 100644
--- a/doctr/models/detection/differentiable_binarization/base.py
+++ b/doctr/models/detection/differentiable_binarization/base.py
@@ -110,7 +110,7 @@ def bitmap_to_boxes(
             else:
                 score = self.box_score(pred, contour, assume_straight_pages=False)
 
-            if self.box_thresh > score:  # remove polygons with a weak objectness
+            if score < self.box_thresh:  # remove polygons with a weak objectness
                 continue
 
             if self.assume_straight_pages:
@@ -253,7 +253,7 @@ def draw_thresh_map(
 
         return polygon, canvas, mask
 
-    def compute_target(
+    def build_target(
         self,
         target: List[np.ndarray],
         output_shape: Tuple[int, int, int],
diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py
index 1a9f60cb66..e363d3de0e 100644
--- a/doctr/models/detection/differentiable_binarization/pytorch.py
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -137,7 +137,7 @@ def __init__(
             nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False),
             nn.BatchNorm2d(head_chans // 4),
             nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(head_chans // 4, 1, 2, stride=2),
+            nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
         )
         self.thresh_head = nn.Sequential(
             conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False),
@@ -213,7 +213,7 @@ def compute_loss(
         prob_map = torch.sigmoid(out_map.squeeze(1))
         thresh_map = torch.sigmoid(thresh_map.squeeze(1))
 
-        targets = self.compute_target(target, prob_map.shape)  # type: ignore[arg-type]
+        targets = self.build_target(target, prob_map.shape)  # type: ignore[arg-type]
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
 
diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py
index 118f263add..8a657e7cb4 100644
--- a/doctr/models/detection/differentiable_binarization/tensorflow.py
+++ b/doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -113,6 +113,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
     Args:
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
+        num_classes: number of output channels in the segmentation map
         assume_straight_pages: if True, fit straight bounding boxes only
         cfg: the configuration dict of the model
     """
@@ -123,6 +124,7 @@ def __init__(
         self,
         feature_extractor: IntermediateLayerGetter,
         fpn_channels: int = 128,  # to be set to 256 to represent the author's initial idea
+        num_classes: int = 1,
         assume_straight_pages: bool = True,
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -144,7 +146,7 @@ def __init__(
                 layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
                 layers.BatchNormalization(),
                 layers.Activation('relu'),
-                layers.Conv2DTranspose(1, 2, strides=2, kernel_initializer='he_normal'),
+                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'),
             ]
         )
         self.threshold_head = keras.Sequential(
@@ -153,7 +155,7 @@ def __init__(
                 layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
                 layers.BatchNormalization(),
                 layers.Activation('relu'),
-                layers.Conv2DTranspose(1, 2, strides=2, kernel_initializer='he_normal'),
+                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'),
             ]
         )
 
@@ -180,7 +182,7 @@ def compute_loss(
         prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1]))
         thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1]))
 
-        seg_target, seg_mask, thresh_target, thresh_mask = self.compute_target(target, out_map.shape[:3])
+        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[:3])
         seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
         thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py
index 7cef94fc56..e3e7b7bee4 100644
--- a/doctr/models/detection/linknet/base.py
+++ b/doctr/models/detection/linknet/base.py
@@ -10,6 +10,7 @@
 import cv2
 import numpy as np
 
+from doctr.file_utils import is_tf_available
 from doctr.models.core import BaseModel
 from doctr.utils.geometry import fit_rbbox, rbbox_to_polygon
 
@@ -71,7 +72,7 @@ def bitmap_to_boxes(
             else:
                 score = self.box_score(pred, contour, assume_straight_pages=False)
 
-            if self.box_thresh > score:  # remove polygons with a weak objectness
+            if score < self.box_thresh:  # remove polygons with a weak objectness
                 continue
 
             if self.assume_straight_pages:
@@ -105,10 +106,10 @@ class _LinkNet(BaseModel):
     min_size_box: int = 3
     assume_straight_pages: bool = True
 
-    def compute_target(
+    def build_target(
         self,
         target: List[np.ndarray],
-        output_shape: Tuple[int, int, int],
+        output_shape: Tuple[int, int],
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
 
         if any(t.dtype != np.float32 for t in target):
@@ -116,13 +117,16 @@
         if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target):
             raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
 
+        h, w = output_shape
+        target_shape = (len(target), h, w, 1)
+
         if self.assume_straight_pages:
-            seg_target = np.zeros(output_shape, dtype=bool)
-            edge_mask = np.zeros(output_shape, dtype=bool)
+            seg_target = np.zeros(target_shape, dtype=bool)
+            edge_mask = np.zeros(target_shape, dtype=bool)
         else:
-            seg_target = np.zeros(output_shape, dtype=np.uint8)
+            seg_target = np.zeros(target_shape, dtype=np.uint8)
 
-        seg_mask = np.ones(output_shape, dtype=bool)
+        seg_mask = np.ones(target_shape, dtype=bool)
 
         for idx, _target in enumerate(target):
             # Draw each polygon on gt
@@ -132,8 +136,8 @@
 
             # Absolute bounding boxes
             abs_boxes = _target.copy()
-            abs_boxes[:, [0, 2]] *= output_shape[-1]
-            abs_boxes[:, [1, 3]] *= output_shape[-2]
+            abs_boxes[:, [0, 2]] *= w
+            abs_boxes[:, [1, 3]] *= h
             abs_boxes = abs_boxes.round().astype(np.int32)
 
             if abs_boxes.shape[1] == 5:
@@ -155,11 +159,19 @@
                 cv2.fillPoly(seg_target[idx], [poly.astype(np.int32)], 1)
             else:
                 seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True
-                # fill the 2 vertical edges
-                edge_mask[idx, max(0, box[1] - 1): min(box[1] + 1, box[3]), box[0]: box[2] + 1] = True
-                edge_mask[idx, max(box[1] + 1, box[3]): min(output_shape[1], box[3] + 2), box[0]: box[2] + 1] = True
-                # fill the 2 horizontal edges
-                edge_mask[idx, box[1]: box[3] + 1, max(0, box[0] - 1): min(box[0] + 1, box[2])] = True
-                edge_mask[idx, box[1]: box[3] + 1, max(box[0] + 1, box[2]): min(output_shape[2], box[2] + 2)] = True
+                # top edge
+                edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True
+                # bot edge
+                edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True
+                # left edge
+                edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True
+                # right edge
+                edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True
+
+        # Don't forget to switch back to channel first if PyTorch is used
+        if not is_tf_available():
+            seg_target = seg_target.transpose(0, 3, 1, 2)
+            seg_mask = seg_mask.transpose(0, 3, 1, 2)
+            edge_mask = edge_mask.transpose(0, 3, 1, 2)
 
         return seg_target, seg_mask, edge_mask
diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py
index 3d7d7e5eaf..0b9e26c0c3 100644
--- a/doctr/models/detection/linknet/pytorch.py
+++ b/doctr/models/detection/linknet/pytorch.py
@@ -166,7 +166,6 @@ def forward(
         target: Optional[List[np.ndarray]] = None,
         return_model_output: bool = False,
         return_boxes: bool = False,
-        focal_loss: bool = True,
         **kwargs: Any,
     ) -> Dict[str, Any]:
 
@@ -185,7 +184,7 @@ def forward(
             out["preds"] = self.postprocessor(prob_map.squeeze(1).detach().cpu().numpy())
 
         if target is not None:
-            loss = self.compute_loss(logits, target, focal_loss)
+            loss = self.compute_loss(logits, target)
             out['loss'] = loss
 
         return out
@@ -194,53 +193,34 @@ def compute_loss(
         self,
         out_map: torch.Tensor,
         target: List[np.ndarray],
-        focal_loss: bool = False,
-        alpha: float = .5,
-        gamma: float = 2.,
         edge_factor: float = 2.,
     ) -> torch.Tensor:
         """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
         `_.
 
         Args:
-            out_map: output feature map of the model of shape N x H x W x 1
+            out_map: output feature map of the model of shape (N, 1, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
-            focal_loss: if True, use focal loss instead of BCE
             edge_factor: boost factor for box edges (in case of BCE)
-            alpha: balancing factor in the focal loss formula
-            gamma: modulating factor in the focal loss formula
 
         Returns:
             A loss tensor
         """
-        targets = self.compute_target(target, out_map.shape)  # type: ignore[arg-type]
+        seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:])  # type: ignore[arg-type]
 
-        seg_target, seg_mask = torch.from_numpy(targets[0]).to(dtype=out_map.dtype), torch.from_numpy(targets[1])
+        seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask)
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
-        edge_mask = torch.from_numpy(targets[2]).to(out_map.device)
+        if edge_factor > 0:
+            edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device)
 
         # Get the cross_entropy for each entry
-        bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')[seg_mask]
+        loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')
 
-        if focal_loss:
-            if gamma and gamma < 0:
-                raise ValueError("Value of gamma should be greater than or equal to zero.")
-
-            # Convert logits to prob, compute gamma factor
-            pred_prob = torch.sigmoid(out_map)[seg_mask]
-            p_t = (seg_target[seg_mask] * pred_prob) + ((1 - seg_target[seg_mask]) * (1 - pred_prob))
-
-            # Compute alpha factor
-            alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)
-
-            # compute the final loss
-            loss = (alpha_factor * (1. - p_t) ** gamma * bce).mean()
-
-        else:
-            # Compute BCE loss with highlighted edges
-            loss = ((1 + (edge_factor - 1) * edge_mask) * bce).mean()
-
-        return loss
+        # Compute BCE loss with highlighted edges
+        if edge_factor > 0:
+            loss = ((1 + (edge_factor - 1) * edge_mask) * loss)
+        # Only consider contributions overlapping the mask
+        return loss[seg_mask].mean()
 
 
 def _linknet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **kwargs: Any) -> LinkNet:
diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py
index 3ca24a7fd7..e7c5d07fa9 100644
--- a/doctr/models/detection/linknet/tensorflow.py
+++ b/doctr/models/detection/linknet/tensorflow.py
@@ -139,9 +139,6 @@ def compute_loss(
         self,
         out_map: tf.Tensor,
         target: List[np.ndarray],
-        focal_loss: bool = False,
-        alpha: float = .5,
-        gamma: float = 2.,
         edge_factor: float = 2.,
     ) -> tf.Tensor:
         """Compute linknet loss, BCE with boosted box edges or focal loss.
         Focal loss implementation based on
 
@@ -150,49 +147,29 @@ def compute_loss(
         Args:
             out_map: output feature map of the model of shape N x H x W x 1
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
-            focal_loss: if True, use focal loss instead of BCE
             edge_factor: boost factor for box edges (in case of BCE)
-            alpha: balancing factor in the focal loss formula
-            gammma: modulating factor in the focal loss formula
 
         Returns:
             A loss tensor
         """
-        seg_target, seg_mask, edge_mask = self.compute_target(target, out_map.shape[:3])
+        seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3])
+
         seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
-        edge_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
+        if edge_factor > 0:
+            edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool)
 
         # Get the cross_entropy for each entry
-        bce = tf.keras.losses.binary_crossentropy(
-            seg_target[seg_mask],
-            tf.squeeze(out_map, axis=[-1])[seg_mask],
-            from_logits=True)
-
-        if focal_loss:
-            if gamma and gamma < 0:
-                raise ValueError("Value of gamma should be greater than or equal to zero.")
-
-            # Convert logits to prob, compute gamma factor
-            pred_prob = tf.sigmoid(tf.squeeze(out_map, axis=[-1])[seg_mask])
-            p_t = (seg_target[seg_mask] * pred_prob) + ((1 - seg_target[seg_mask]) * (1 - pred_prob))
-            modulating_factor = tf.pow((1.0 - p_t), gamma)
-
-            # Compute alpha factor
-            alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)
-
-            # compute the final loss
-            loss = tf.reduce_mean(alpha_factor * modulating_factor * bce)
+        loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]
 
-        else:
-            # Compute BCE loss with highlighted edges
+        # Compute BCE loss with highlighted edges
+        if edge_factor > 0:
             loss = tf.math.multiply(
                 1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype),
-                bce
+                loss
             )
-            loss = tf.reduce_mean(loss)
 
-        return loss
+        return tf.reduce_mean(loss[seg_mask])
 
     def call(
         self,
@@ -200,7 +177,6 @@ def call(
         x: tf.Tensor,
         target: Optional[List[np.ndarray]] = None,
         return_model_output: bool = False,
         return_boxes: bool = False,
-        focal_loss: bool = True,
         **kwargs: Any,
     ) -> Dict[str, Any]:
 
@@ -219,7 +195,7 @@ def call(
             out["preds"] = self.postprocessor(tf.squeeze(prob_map, axis=-1).numpy())
 
         if target is not None:
-            loss = self.compute_loss(logits, target, focal_loss)
+            loss = self.compute_loss(logits, target)
             out['loss'] = loss
 
         return out
diff --git a/doctr/models/recognition/core.py b/doctr/models/recognition/core.py
index f89def2b5b..ff87478129 100644
--- a/doctr/models/recognition/core.py
+++ b/doctr/models/recognition/core.py
@@ -19,7 +19,7 @@ class RecognitionModel(NestedObject):
     vocab: str
     max_length: int
 
-    def compute_target(
+    def build_target(
         self,
         gts: List[str],
     ) -> Tuple[np.ndarray, List[int]]:
diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py
index 92c1db5e42..9bdbac759b 100644
--- a/doctr/models/recognition/crnn/pytorch.py
+++ b/doctr/models/recognition/crnn/pytorch.py
@@ -160,7 +160,7 @@ def compute_loss(
         Returns:
             The loss of the model on the batch
         """
-        gt, seq_len = self.compute_target(target)
+        gt, seq_len = self.build_target(target)
         batch_len = model_output.shape[0]
         input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
         # N x T x C -> T x N x C
diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py
index 0415f251c7..5c9fb940a8 100644
--- a/doctr/models/recognition/crnn/tensorflow.py
+++ b/doctr/models/recognition/crnn/tensorflow.py
@@ -147,7 +147,7 @@ def compute_loss(
         Returns:
             The loss of the model on the batch
         """
-        gt, seq_len = self.compute_target(target)
+        gt, seq_len = self.build_target(target)
         batch_len = model_output.shape[0]
         input_length = tf.fill((batch_len,), model_output.shape[1])
         ctc_loss = tf.nn.ctc_loss(
diff --git a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py
index 0e84b83e3b..6134384f3f 100644
--- a/doctr/models/recognition/master/base.py
+++ b/doctr/models/recognition/master/base.py
@@ -16,7 +16,7 @@ class _MASTER:
     vocab: str
     max_length: int
 
-    def compute_target(
+    def build_target(
         self,
         gts: List[str],
     ) -> Tuple[np.ndarray, List[int]]:
diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py
index 24bff799d4..2ed4b0f6dc 100644
--- a/doctr/models/recognition/master/pytorch.py
+++ b/doctr/models/recognition/master/pytorch.py
@@ -278,7 +278,7 @@ def forward(
 
         if target is not None:
             # Compute target: tensor of gts and sequence lengths
-            _gt, _seq_len = self.compute_target(target)
+            _gt, _seq_len = self.build_target(target)
             gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
             gt, seq_len = gt.to(x.device), seq_len.to(x.device)
 
diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py
index 2007b626f6..f63b608631 100644
--- a/doctr/models/recognition/master/tensorflow.py
+++ b/doctr/models/recognition/master/tensorflow.py
@@ -292,7 +292,7 @@ def call(
 
         if target is not None:
             # Compute target: tensor of gts and sequence lengths
-            gt, seq_len = self.compute_target(target)
+            gt, seq_len = self.build_target(target)
 
         if kwargs.get('training', False):
             if target is None:
diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py
index b4ac4a79f1..9e5638c834 100644
--- a/doctr/models/recognition/sar/pytorch.py
+++ b/doctr/models/recognition/sar/pytorch.py
@@ -188,7 +188,7 @@ def forward(
         _, (encoded, _) = self.encoder(pooled_features)
         encoded = encoded[-1]
         if target is not None:
-            _gt, _seq_len = self.compute_target(target)
+            _gt, _seq_len = self.build_target(target)
             gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)  # type: ignore[assignment]
             gt, seq_len = gt.to(x.device), seq_len.to(x.device)
         decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py
index f4e3a51fa9..b6cf440384 100644
--- a/doctr/models/recognition/sar/tensorflow.py
+++ b/doctr/models/recognition/sar/tensorflow.py
@@ -258,7 +258,7 @@ def call(
         pooled_features = tf.reduce_max(features, axis=1)  # vertical max pooling
         encoded = self.encoder(pooled_features, **kwargs)
         if target is not None:
-            gt, seq_len = self.compute_target(target)
+            gt, seq_len = self.build_target(target)
             seq_len = tf.cast(seq_len, tf.int32)
         decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
 
diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py
index f9d3d39569..b50690f62f 100644
--- a/tests/pytorch/test_datasets_pt.py
+++ b/tests/pytorch/test_datasets_pt.py
@@ -1,3 +1,6 @@
+import os
+from shutil import move
+
 import numpy as np
 import pytest
 import torch
@@ -103,6 +106,13 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     _, r_target = rotated_ds[0]
     assert r_target.shape[1] == 5
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.DetectionDataset(mock_image_folder, mock_detection_label)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_recognition_dataset(mock_image_folder, mock_recognition_label):
     input_size = (32, 128)
@@ -123,6 +133,13 @@ def test_recognition_dataset(mock_image_folder, mock_recognition_label):
     assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size)
     assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels)
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.RecognitionDataset(mock_image_folder, mock_recognition_label)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_ocrdataset(mock_ocrdataset):
 
@@ -152,6 +169,13 @@
     assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size)
     assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.OCRDataset(*mock_ocrdataset)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_charactergenerator():
diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py
index 13fa79f980..66dd2507de 100644
--- a/tests/pytorch/test_models_detection_pt.py
+++ b/tests/pytorch/test_models_detection_pt.py
@@ -3,6 +3,7 @@
 import torch
 
 from doctr.models import detection
+from doctr.models.detection._utils import dilate, erode
 from doctr.models.detection.predictor import DetectionPredictor
 
@@ -67,3 +68,19 @@ def test_detection_zoo(arch_name):
     with torch.no_grad():
         out = predictor(input_tensor)
     assert all(isinstance(boxes, np.ndarray) and boxes.shape[1] == 5 for boxes in out)
+
+
+def test_erode():
+    x = torch.zeros((1, 1, 3, 3))
+    x[..., 1, 1] = 1
+    expected = torch.zeros((1, 1, 3, 3))
+    out = erode(x, 3)
+    assert torch.equal(out, expected)
+
+
+def test_dilate():
+    x = torch.zeros((1, 1, 3, 3))
+    x[..., 1, 1] = 1
+    expected = torch.ones((1, 1, 3, 3))
+    out = dilate(x, 3)
+    assert torch.equal(out, expected)
diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py
index ee42a76d04..f5034b7204 100644
--- a/tests/tensorflow/test_datasets_tf.py
+++ b/tests/tensorflow/test_datasets_tf.py
@@ -1,3 +1,6 @@
+import os
+from shutil import move
+
 import numpy as np
 import pytest
 import tensorflow as tf
@@ -90,6 +93,13 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     _, r_target = rotated_ds[0]
     assert r_target.shape[1] == 5
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.DetectionDataset(mock_image_folder, mock_detection_label)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_recognition_dataset(mock_image_folder, mock_recognition_label):
     input_size = (32, 128)
@@ -110,6 +120,13 @@
     assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3)
     assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels)
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.RecognitionDataset(mock_image_folder, mock_recognition_label)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_ocrdataset(mock_ocrdataset):
 
@@ -138,6 +155,13 @@
     assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3)
     assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
 
+    # File existence check
+    img_name, _ = ds.data[0]
+    move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file"))
+    with pytest.raises(FileNotFoundError):
+        datasets.OCRDataset(*mock_ocrdataset)
+    move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name))
+
 
 def test_charactergenerator():
diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py
index 6f475885b2..69df2d25f8 100644
--- a/tests/tensorflow/test_models_detection_tf.py
+++ b/tests/tensorflow/test_models_detection_tf.py
@@ -4,6 +4,7 @@
 
 from doctr.io import DocumentFile
 from doctr.models import detection
+from doctr.models.detection._utils import dilate, erode
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.preprocessor import PreProcessor
 
@@ -139,3 +140,21 @@ def test_linknet_focal_loss():
     # test focal loss
     out = model(input_tensor, target, return_model_output=True, return_boxes=True, training=True, focal_loss=True)
     assert isinstance(out['loss'], tf.Tensor)
+
+
+def test_erode():
+    x = np.zeros((1, 3, 3, 1), dtype=np.float32)
+    x[:, 1, 1] = 1
+    x = tf.convert_to_tensor(x)
+    expected = tf.zeros((1, 3, 3, 1))
+    out = erode(x, 3)
+    assert tf.math.reduce_all(out == expected)
+
+
+def test_dilate():
+    x = np.zeros((1, 3, 3, 1), dtype=np.float32)
+    x[:, 1, 1] = 1
+    x = tf.convert_to_tensor(x)
+    expected = tf.ones((1, 3, 3, 1))
+    out = dilate(x, 3)
+    assert tf.math.reduce_all(out == expected)