Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: ⚡ Speeded up CTC decoding in PyTorch by x10 #633

Merged
merged 7 commits into from
Nov 18, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: Fixed decode_sequence
  • Loading branch information
fg-mindee committed Nov 17, 2021
commit a936512414b98e425dac3ed8e1a30bbb4dcf6031
10 changes: 7 additions & 3 deletions doctr/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import unicodedata
from collections import Sequence
from functools import partial
- from typing import Any, List, Optional, Sequence, Union
+ from typing import Any, List, Optional
+ from typing import Sequence as SequenceType
+ from typing import Union

import numpy as np

Expand Down Expand Up @@ -67,7 +69,7 @@ def encode_string(


def decode_sequence(
- input_seq: Union[np.array, Sequence[int]],
+ input_seq: Union[np.array, SequenceType[int]],
mapping: str,
) -> str:
"""Given a predefined mapping, decode the sequence of numbers to a string
Expand All @@ -80,7 +82,9 @@ def decode_sequence(
A string, decoded from input_seq
"""

- if not isinstance(input_seq, Sequence) or input_seq.dtype != np.int_ or input_seq.max() >= len(mapping):
+ if not isinstance(input_seq, (Sequence, np.ndarray)):
+ raise TypeError("Invalid sequence type")
+ if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)):
raise AssertionError("Input must be an array of int, with max less than mapping size")

return ''.join(map(mapping.__getitem__, input_seq))
Expand Down