Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: ⚡ Speeded up CTC decoding in PyTorch by x10 #633

Merged
merged 7 commits into from
Nov 18, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Added support of basic sequence to decode_sequence
  • Loading branch information
fg-mindee committed Nov 17, 2021
commit 9f5d0f802caee7531ea9af48e36eaea4ab17881b
13 changes: 7 additions & 6 deletions doctr/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import string
import unicodedata
from collections import Sequence
from functools import partial
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence, Union

import numpy as np

Expand Down Expand Up @@ -66,23 +67,23 @@ def encode_string(


def decode_sequence(
    input_seq: Union[np.ndarray, Sequence[int]],
    mapping: str,
) -> str:
    """Given a predefined mapping, decode the sequence of numbers to a string

    Args:
        input_seq: indices to decode, either a NumPy array of dtype ``np.int_``
            or a plain sequence (e.g. list or tuple) of ints
        mapping: vocabulary (string), the encoding is given by the indexing of the character sequence

    Returns:
        A string, decoded from input_seq (empty if input_seq is empty)

    Raises:
        AssertionError: if input_seq is neither a NumPy array nor a sequence, or if it is
            an array whose dtype is not ``np.int_`` or whose max is not less than the mapping size
    """

    # Imported locally under a private alias: the abstract base class is needed for the
    # runtime type check and only lives in `collections.abc` on recent Python versions.
    from collections.abc import Sequence as _AbcSequence

    if isinstance(input_seq, np.ndarray):
        # `.max()` raises on an empty array, so the bound is only checked when there is data
        if input_seq.dtype != np.int_ or (input_seq.size > 0 and input_seq.max() >= len(mapping)):
            raise AssertionError("Input must be an array of int, with max less than mapping size")
    elif not isinstance(input_seq, _AbcSequence):
        raise AssertionError("Input must be an array of int, with max less than mapping size")

    # Plain sequences are not range-checked upfront to keep decoding fast:
    # an out-of-range index surfaces as an IndexError from the string lookup below.
    return ''.join(map(mapping.__getitem__, input_seq))


def encode_sequences(
Expand Down