From 9f5d0f802caee7531ea9af48e36eaea4ab17881b Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:25:39 +0100
Subject: [PATCH 1/7] feat: Added support of basic sequence to decode_sequence

---
 doctr/datasets/utils.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py
index adb6b54464..514ebfaadf 100644
--- a/doctr/datasets/utils.py
+++ b/doctr/datasets/utils.py
@@ -5,8 +5,9 @@
 
 import string
 import unicodedata
+from collections import Sequence
 from functools import partial
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence, Union
 
 import numpy as np
 
@@ -66,23 +67,23 @@ def encode_string(
 
 
 def decode_sequence(
-    input_array: np.array,
+    input_seq: Union[np.array, Sequence[int]],
     mapping: str,
 ) -> str:
     """Given a predefined mapping, decode the sequence of numbers to a string
 
     Args:
-        input_array: array to decode
+        input_seq: array to decode
         mapping: vocabulary (string), the encoding is given by the indexing of the character sequence
 
     Returns:
-        A string, decoded from input_array
+        A string, decoded from input_seq
     """
 
-    if not input_array.dtype == np.int_ or input_array.max() >= len(mapping):
+    if not isinstance(input_seq, Sequence) or input_seq.dtype != np.int_ or input_seq.max() >= len(mapping):
         raise AssertionError("Input must be an array of int, with max less than mapping size")
 
-    return ''.join(map(mapping.__getitem__, input_array))
+    return ''.join(map(mapping.__getitem__, input_seq))
 
 
 def encode_sequences(

From 3d84d268d1305b6712c1fb0d20c2ebe830a6f268 Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:25:50 +0100
Subject: [PATCH 2/7] test: Updated unittests

---
 tests/common/test_datasets_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/common/test_datasets_utils.py b/tests/common/test_datasets_utils.py
index 35426e1aa1..aa365f5dc9 100644
--- a/tests/common/test_datasets_utils.py
+++ b/tests/common/test_datasets_utils.py
@@ -34,7 +34,7 @@ def test_encode_decode(input_str):
     mapping = """3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|
za1ù8,OG€P-kçHëÀÂ2É/ûIJ\'j(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l"""
     encoded = utils.encode_string(input_str, mapping)
-    decoded = utils.decode_sequence(np.array(encoded), mapping)
+    decoded = utils.decode_sequence(encoded, mapping)
 
     assert decoded == input_str
 

From 4b7b032f8e64df0dcd92c18418592322da892f87 Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:26:27 +0100
Subject: [PATCH 3/7] feat: Speeded up CTC decoding in PyTorch

---
 doctr/models/recognition/crnn/pytorch.py | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py
index c1e9dedae4..92c1db5e42 100644
--- a/doctr/models/recognition/crnn/pytorch.py
+++ b/doctr/models/recognition/crnn/pytorch.py
@@ -11,7 +11,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-from ....datasets import VOCABS
+from ....datasets import VOCABS, decode_sequence
 from ...backbones import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn
 from ...utils import load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
@@ -69,20 +69,15 @@ def ctc_best_path(
     Returns:
         A list of tuples: (word, confidence)
     """
-    # compute softmax
-    probs = F.softmax(logits, dim=-1)
-    # get char indices along best path
-    best_path = torch.argmax(probs, dim=-1)
-    # define word proba as min proba of sequence
-    probs, _ = torch.max(probs, dim=-1)
-    probs, _ = torch.min(probs, dim=1)
-
-    words = []
-    for sequence in best_path:
-        # collapse best path (using itertools.groupby), map to chars, join char list to string
-        collapsed = [vocab[k] for k, _ in groupby(sequence) if k != blank]
-        res = ''.join(collapsed)
-        words.append(res)
+
+    # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+    probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+    # collapse best path (using itertools.groupby), map to chars, join char list to string
+    words = [
+        decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+        for seq in torch.argmax(logits, dim=-1)
+    ]
 
     return list(zip(words, probs.tolist()))
 

From a936512414b98e425dac3ed8e1a30bbb4dcf6031 Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:30:32 +0100
Subject: [PATCH 4/7] fix: Fixed decode_sequence

---
 doctr/datasets/utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py
index 514ebfaadf..84b164c672 100644
--- a/doctr/datasets/utils.py
+++ b/doctr/datasets/utils.py
@@ -7,7 +7,9 @@
 import unicodedata
 from collections import Sequence
 from functools import partial
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, List, Optional
+from typing import Sequence as SequenceType
+from typing import Union
 
 import numpy as np
 
@@ -67,7 +69,7 @@ def encode_string(
 
 
 def decode_sequence(
-    input_seq: Union[np.array, Sequence[int]],
+    input_seq: Union[np.array, SequenceType[int]],
     mapping: str,
 ) -> str:
     """Given a predefined mapping, decode the sequence of numbers to a string
@@ -80,7 +82,9 @@ def decode_sequence(
         A string, decoded from input_seq
     """
 
-    if not isinstance(input_seq, Sequence) or input_seq.dtype != np.int_ or input_seq.max() >= len(mapping):
+    if not isinstance(input_seq, (Sequence, np.ndarray)):
+        raise TypeError("Invalid sequence type")
+    if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)):
         raise AssertionError("Input must be an array of int, with max less than mapping size")
 
     return ''.join(map(mapping.__getitem__, input_seq))

From f00123c44ae3c75389919eac97b7de15850e011b Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:36:20 +0100
Subject: [PATCH 5/7] style: Silenced import warning

---
 doctr/datasets/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py
index 84b164c672..66b16b6b2c 100644
--- a/doctr/datasets/utils.py
+++ b/doctr/datasets/utils.py
@@ -5,7 +5,7 @@
 
 import string
 import unicodedata
-from collections import Sequence
+from collections.abc import Sequence
 from functools import partial
 from typing import Any, List, Optional
 from typing import Sequence as SequenceType

From bda3773684c92d71678dca2751255559c9dd93ec Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:36:33 +0100
Subject: [PATCH 6/7] test: Extended decode_sequence unittest

---
 tests/common/test_datasets_utils.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/common/test_datasets_utils.py b/tests/common/test_datasets_utils.py
index aa365f5dc9..e33662808c 100644
--- a/tests/common/test_datasets_utils.py
+++ b/tests/common/test_datasets_utils.py
@@ -38,6 +38,16 @@ def test_encode_decode(input_str):
     assert decoded == input_str
 
 
+def test_decode_sequence():
+    mapping = "abcdef"
+    with pytest.raises(TypeError):
+        utils.decode_sequence(123, mapping)
+    with pytest.raises(AssertionError):
+        utils.decode_sequence(np.array([2, 10]), mapping)
+    with pytest.raises(AssertionError):
+        utils.decode_sequence(np.array([2, 4.5]), mapping)
+
+
 @pytest.mark.parametrize(
     "sequences, vocab, target_size, sos, eos, pad, dynamic_len, error, out_shape, gts",
     [

From bbf5676d3a3abca6b630002a961eb54bbb7f8daa Mon Sep 17 00:00:00 2001
From: fg-mindee
Date: Wed, 17 Nov 2021 19:37:38 +0100
Subject: [PATCH 7/7] test: Extended unittest

---
 tests/common/test_datasets_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/common/test_datasets_utils.py b/tests/common/test_datasets_utils.py
index e33662808c..b69c5d1daf 100644
--- a/tests/common/test_datasets_utils.py
+++ b/tests/common/test_datasets_utils.py
@@ -47,6 +47,8 @@ def test_decode_sequence():
     with pytest.raises(AssertionError):
         utils.decode_sequence(np.array([2, 4.5]), mapping)
 
+    assert utils.decode_sequence([3, 4, 3, 4], mapping) == "dede"
+
 
 @pytest.mark.parametrize(
     "sequences, vocab, target_size, sos, eos, pad, dynamic_len, error, out_shape, gts",
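
Note: the snippet below is not part of the patch series. It is a minimal standalone sketch of the reworked CTC best-path decoding introduced in PATCH 3/7, for trying the logic outside doctr; the function name, the toy vocab and the blank-at-last-index convention are illustrative assumptions.

from itertools import groupby

import torch
from torch.nn import functional as F


def ctc_best_path_sketch(logits: torch.Tensor, vocab: str):
    # assumption: logits has shape (batch, seq_len, len(vocab) + 1), blank is the last class
    blank = len(vocab)
    # sequence confidence = smallest per-step max probability along the best path
    probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
    # argmax per step, collapse repeats with groupby, drop blanks, map indices to chars
    words = [
        ''.join(vocab[k] for k, _ in groupby(seq.tolist()) if k != blank)
        for seq in torch.argmax(logits, dim=-1)
    ]
    return list(zip(words, probs.tolist()))


# toy check: batch of 2 sequences, 32 steps, 6 characters + blank
print(ctc_best_path_sketch(torch.randn(2, 32, 7), "abcdef"))

The min-of-max confidence is the same quantity the patch computes, only fused into a single tensor expression instead of separate softmax/max/min calls and a Python loop over the batch.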