diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py index adb6b54464..66b16b6b2c 100644 --- a/doctr/datasets/utils.py +++ b/doctr/datasets/utils.py @@ -5,8 +5,11 @@ import string import unicodedata +from collections.abc import Sequence from functools import partial from typing import Any, List, Optional +from typing import Sequence as SequenceType +from typing import Union import numpy as np @@ -66,23 +69,25 @@ def encode_string( def decode_sequence( - input_array: np.array, + input_seq: Union[np.array, SequenceType[int]], mapping: str, ) -> str: """Given a predefined mapping, decode the sequence of numbers to a string Args: - input_array: array to decode + input_seq: array to decode mapping: vocabulary (string), the encoding is given by the indexing of the character sequence Returns: - A string, decoded from input_array + A string, decoded from input_seq """ - if not input_array.dtype == np.int_ or input_array.max() >= len(mapping): + if not isinstance(input_seq, (Sequence, np.ndarray)): + raise TypeError("Invalid sequence type") + if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)): raise AssertionError("Input must be an array of int, with max less than mapping size") - return ''.join(map(mapping.__getitem__, input_array)) + return ''.join(map(mapping.__getitem__, input_seq)) def encode_sequences( diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index c1e9dedae4..92c1db5e42 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -11,7 +11,7 @@ from torch import nn from torch.nn import functional as F -from ....datasets import VOCABS +from ....datasets import VOCABS, decode_sequence from ...backbones import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn from ...utils import load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor @@ -69,20 +69,15 @@ def ctc_best_path( Returns: A list of tuples: (word, confidence) """ - # compute softmax - probs = F.softmax(logits, dim=-1) - # get char indices along best path - best_path = torch.argmax(probs, dim=-1) - # define word proba as min proba of sequence - probs, _ = torch.max(probs, dim=-1) - probs, _ = torch.min(probs, dim=1) - - words = [] - for sequence in best_path: - # collapse best path (using itertools.groupby), map to chars, join char list to string - collapsed = [vocab[k] for k, _ in groupby(sequence) if k != blank] - res = ''.join(collapsed) - words.append(res) + + # Gather the most confident characters, and assign the smallest conf among those to the sequence prob + probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values + + # collapse best path (using itertools.groupby), map to chars, join char list to string + words = [ + decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) + for seq in torch.argmax(logits, dim=-1) + ] return list(zip(words, probs.tolist())) diff --git a/tests/common/test_datasets_utils.py b/tests/common/test_datasets_utils.py index 35426e1aa1..b69c5d1daf 100644 --- a/tests/common/test_datasets_utils.py +++ b/tests/common/test_datasets_utils.py @@ -34,10 +34,22 @@ def test_encode_decode(input_str): mapping = """3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£| za1ù8,OG€P-kçHëÀÂ2É/ûIJ\'j(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l""" encoded = utils.encode_string(input_str, mapping) - decoded = utils.decode_sequence(np.array(encoded), mapping) + decoded = utils.decode_sequence(encoded, mapping) assert decoded == input_str +def test_decode_sequence(): + mapping = "abcdef" + with pytest.raises(TypeError): + utils.decode_sequence(123, mapping) + with pytest.raises(AssertionError): + utils.decode_sequence(np.array([2, 10]), mapping) + with pytest.raises(AssertionError): + utils.decode_sequence(np.array([2, 4.5]), mapping) + + assert utils.decode_sequence([3, 4, 3, 4], mapping) == "dede" + + @pytest.mark.parametrize( "sequences, vocab, target_size, sos, eos, pad, dynamic_len, error, out_shape, gts", [