Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace unidecode with text-unidecode. #1509

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Replace text-unidecode with anyascii.
  • Loading branch information
jonatankawalek committed Mar 12, 2024
commit 305241ffccda64037be9675ef648f32f68217443
2 changes: 1 addition & 1 deletion .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ requirements:
- weasyprint >=55.0
- defusedxml >=0.7.0
- mplcursors >=0.3
- text-unidecode >=1.3
- anyascii >=0.3.2
- tqdm >=4.30.0

test:
Expand Down
32 changes: 16 additions & 16 deletions doctr/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from typing import Dict, List, Optional, Tuple

from anyascii import anyascii
import cv2
import numpy as np
from scipy.optimize import linear_sum_assignment
from text_unidecode import unidecode

__all__ = [
"TextMatch",
Expand All @@ -34,16 +34,16 @@ def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:
Returns:
-------
a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
unidecode counterparts and their lower-case unidecode counterparts match
anyascii counterparts and their lower-case anyascii counterparts match
"""
raw_match = word1 == word2
caseless_match = word1.lower() == word2.lower()
unidecode_match = unidecode(word1) == unidecode(word2)
anyascii_match = anyascii(word1) == anyascii(word2)

# Warning: the order is important here (transliterate first, then lower-case); otherwise the pair ("EUR", "€") cannot be matched
unicase_match = unidecode(word1).lower() == unidecode(word2).lower()
unicase_match = anyascii(word1).lower() == anyascii(word2).lower()

return raw_match, caseless_match, unidecode_match, unicase_match
return raw_match, caseless_match, anyascii_match, unicase_match


class TextMatch:
Expand Down Expand Up @@ -94,10 +94,10 @@ def update(
raise AssertionError("prediction size does not match with ground-truth labels size")

for gt_word, pred_word in zip(gt, pred):
_raw, _caseless, _unidecode, _unicase = string_match(gt_word, pred_word)
_raw, _caseless, _anyascii, _unicase = string_match(gt_word, pred_word)
self.raw += int(_raw)
self.caseless += int(_caseless)
self.unidecode += int(_unidecode)
self.anyascii += int(_anyascii)
self.unicase += int(_unicase)

self.total += len(gt)
Expand All @@ -107,23 +107,23 @@ def summary(self) -> Dict[str, float]:

Returns
-------
a dictionary with the exact match score for the raw data, its lower-case counterpart, its unidecode
counterpart and its lower-case unidecode counterpart
a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
counterpart and its lower-case anyascii counterpart
"""
if self.total == 0:
raise AssertionError("you need to update the metric before getting the summary")

return dict(
raw=self.raw / self.total,
caseless=self.caseless / self.total,
unidecode=self.unidecode / self.total,
anyascii=self.anyascii / self.total,
unicase=self.unicase / self.total,
)

def reset(self) -> None:
self.raw = 0
self.caseless = 0
self.unidecode = 0
self.anyascii = 0
self.unicase = 0
self.total = 0

Expand Down Expand Up @@ -544,10 +544,10 @@ def update(
is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh
# String comparison
for gt_idx, pred_idx in zip(gt_indices[is_kept], pred_indices[is_kept]):
_raw, _caseless, _unidecode, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
_raw, _caseless, _anyascii, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
self.raw_matches += int(_raw)
self.caseless_matches += int(_caseless)
self.unidecode_matches += int(_unidecode)
self.anyascii_matches += int(_anyascii)
self.unicase_matches += int(_unicase)

self.num_gts += gt_boxes.shape[0]
Expand All @@ -564,15 +564,15 @@ def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]
recall = dict(
raw=self.raw_matches / self.num_gts if self.num_gts > 0 else None,
caseless=self.caseless_matches / self.num_gts if self.num_gts > 0 else None,
unidecode=self.unidecode_matches / self.num_gts if self.num_gts > 0 else None,
anyascii=self.anyascii_matches / self.num_gts if self.num_gts > 0 else None,
unicase=self.unicase_matches / self.num_gts if self.num_gts > 0 else None,
)

# Precision
precision = dict(
raw=self.raw_matches / self.num_preds if self.num_preds > 0 else None,
caseless=self.caseless_matches / self.num_preds if self.num_preds > 0 else None,
unidecode=self.unidecode_matches / self.num_preds if self.num_preds > 0 else None,
anyascii=self.anyascii_matches / self.num_preds if self.num_preds > 0 else None,
unicase=self.unicase_matches / self.num_preds if self.num_preds > 0 else None,
)

Expand All @@ -587,7 +587,7 @@ def reset(self) -> None:
self.tot_iou = 0.0
self.raw_matches = 0
self.caseless_matches = 0
self.unidecode_matches = 0
self.anyascii_matches = 0
self.unicase_matches = 0


Expand Down
10 changes: 5 additions & 5 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

from anyascii import anyascii
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
from matplotlib.figure import Figure
from PIL import Image, ImageDraw
from text_unidecode import unidecode

from .common_types import BoundingBox, Polygon4P
from .fonts import get_font
Expand Down Expand Up @@ -327,8 +327,8 @@
try:
d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))
except UnicodeEncodeError:
# When character cannot be encoded, use its unidecode version
d.text((0, 0), unidecode(word["value"]), font=font, fill=(0, 0, 0))
# When a character cannot be encoded, fall back to its anyascii (ASCII-transliterated) version
d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0))

Check warning on line 331 in doctr/utils/visualization.py

View check run for this annotation

Codecov / codecov/patch

doctr/utils/visualization.py#L331

Added line #L331 was not covered by tests

# Colorize if draw_proba
if draw_proba:
Expand Down Expand Up @@ -458,8 +458,8 @@
try:
d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0))
except UnicodeEncodeError:
# When character cannot be encoded, use its unidecode version
d.text((0, 0), unidecode(prediction["value"]), font=font, fill=(0, 0, 0))
# When a character cannot be encoded, fall back to its anyascii (ASCII-transliterated) version
d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0))

Check warning on line 462 in doctr/utils/visualization.py

View check run for this annotation

Codecov / codecov/patch

doctr/utils/visualization.py#L462

Added line #L462 was not covered by tests

# Colorize if draw_proba
if draw_proba:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ dependencies = [
"Pillow>=9.2.0",
"defusedxml>=0.7.0",
"mplcursors>=0.3",
"text-unidecode>=1.3",
"anyascii>=0.3.2",
"tqdm>=4.30.0",
]

Expand Down
26 changes: 13 additions & 13 deletions tests/common/test_utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


@pytest.mark.parametrize(
"gt, pred, raw, caseless, unidecode, unicase",
"gt, pred, raw, caseless, anyascii, unicase",
[
[["grass", "56", "True", "EUR"], ["grass", "56", "true", "€"], 0.5, 0.75, 0.75, 1],
[["éléphant", "ça"], ["elephant", "ca"], 0, 0, 1, 1],
],
)
def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
def test_text_match(gt, pred, raw, caseless, anyascii, unicase):
metric = metrics.TextMatch()
with pytest.raises(AssertionError):
metric.summary()
Expand All @@ -20,10 +20,10 @@ def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
metric.update(["a", "b"], ["c"])

metric.update(gt, pred)
assert metric.summary() == dict(raw=raw, caseless=caseless, unidecode=unidecode, unicase=unicase)
assert metric.summary() == dict(raw=raw, caseless=caseless, anyascii=anyascii, unicase=unicase)

metric.reset()
assert metric.raw == metric.caseless == metric.unidecode == metric.unicase == metric.total == 0
assert metric.raw == metric.caseless == metric.anyascii == metric.unicase == metric.total == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -208,8 +208,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]]],
[["elephant"]],
0.5,
{"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
1,
],
[ # Bad match
Expand All @@ -218,8 +218,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]]],
[["elephant"]],
0.5,
{"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
1,
],
[ # Good match
Expand All @@ -228,8 +228,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5], [0.6, 0.6, 0.7, 0.7]]],
[["€", "e"]],
0.2,
{"raw": 0, "caseless": 0, "unidecode": 1, "unicase": 1},
{"raw": 0, "caseless": 0, "unidecode": 0.5, "unicase": 0.5},
{"raw": 0, "caseless": 0, "anyascii": 1, "unicase": 1},
{"raw": 0, "caseless": 0, "anyascii": 0.5, "unicase": 0.5},
0.13,
],
[ # No preds on 2nd sample
Expand All @@ -238,8 +238,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]], None],
[["elephant"], []],
0.5,
{"raw": 0, "caseless": 0.5, "unidecode": 0, "unicase": 0.5},
{"raw": 0, "caseless": 1, "unidecode": 0, "unicase": 1},
{"raw": 0, "caseless": 0.5, "anyascii": 0, "unicase": 0.5},
{"raw": 0, "caseless": 1, "anyascii": 0, "unicase": 1},
1,
],
],
Expand All @@ -256,7 +256,7 @@ def test_ocr_metric(gt_boxes, gt_words, pred_boxes, pred_words, iou_thresh, reca
assert _mean_iou == mean_iou
metric.reset()
assert metric.num_gts == metric.num_preds == metric.tot_iou == 0
assert metric.raw_matches == metric.caseless_matches == metric.unidecode_matches == metric.unicase_matches == 0
assert metric.raw_matches == metric.caseless_matches == metric.anyascii_matches == metric.unicase_matches == 0
# Shape check
with pytest.raises(AssertionError):
metric.update(
Expand Down
Loading