Skip to content

Commit

Permalink
[references] Add eval recognition and update eval detection scripts (#…
Browse files Browse the repository at this point in the history
…933)

* add recognition eval scripts

* update CI job and detection scripts

* update CI job

* speedup reco eval script CI

* rename CI job names
  • Loading branch information
felixdittrich92 authored Jun 13, 2022
1 parent 97da310 commit 210ecc4
Show file tree
Hide file tree
Showing 6 changed files with 412 additions and 10 deletions.
88 changes: 88 additions & 0 deletions .github/workflows/references.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,50 @@ jobs:
name: Train for a short epoch (PT)
run: python references/recognition/train_pytorch.py crnn_mobilenet_v3_small --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1

evaluate-text-recognition:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.8]
framework: [tensorflow, pytorch]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
architecture: x64
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
- if: matrix.framework == 'tensorflow'
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- if: matrix.framework == 'tensorflow'
name: Evaluate text recognition (TF)
run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset IIIT5K
- if: matrix.framework == 'pytorch'
name: Evaluate text recognition (PT)
run: python references/recognition/evaluate_pytorch.py crnn_mobilenet_v3_small --dataset IIIT5K

latency-text-recognition:
runs-on: ${{ matrix.os }}
strategy:
Expand Down Expand Up @@ -213,6 +257,50 @@ jobs:
name: Train for a short epoch (PT)
run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1

evaluate-text-detection:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.8]
framework: [tensorflow, pytorch]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
architecture: x64
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
- if: matrix.framework == 'tensorflow'
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- if: matrix.framework == 'tensorflow'
name: Evaluate text detection (TF)
run: python references/detection/evaluate_tensorflow.py db_mobilenet_v3_large
- if: matrix.framework == 'pytorch'
name: Evaluate text detection (PT)
run: python references/detection/evaluate_pytorch.py db_mobilenet_v3_large

latency-text-detection:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
2 changes: 1 addition & 1 deletion doctr/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
Returns:
a list of cropped images
"""
img = np.array(Image.open(img_path))
img = np.array(Image.open(img_path).convert('RGB'))
# Polygon
if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
return extract_rcrops(img, geoms.astype(dtype=int))
Expand Down
15 changes: 10 additions & 5 deletions references/detection/evaluate_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
targets = [t['boxes'] for t in targets]
if amp:
with torch.cuda.amp.autocast():
out = model(images, targets, return_boxes=True)
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_boxes=True)
out = model(images, targets, return_preds=True)
# Compute metric
loc_preds = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
Expand Down Expand Up @@ -80,14 +80,19 @@ def main(args):
ds = datasets.__dict__[args.dataset](
train=True,
download=True,
rotated_bbox=args.rotation,
use_polygons=args.rotation,
sample_transforms=T.Resize(input_shape),
)
# Monkeypatch
subfolder = ds.root.split("/")[-2:]
ds.root = str(Path(ds.root).parent.parent)
ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data]
_ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation)
_ds = datasets.__dict__[args.dataset](
train=False,
download=True,
use_polygons=args.rotation,
sample_transforms=T.Resize(input_shape),
)
subfolder = _ds.root.split("/")[-2:]
ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data])

Expand Down Expand Up @@ -127,7 +132,7 @@ def main(args):
model = model.cuda()

# Metrics
metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape)
metric = LocalizationConfusion(use_polygons=args.rotation, mask_shape=input_shape)

print("Running evaluation")
val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp)
Expand Down
13 changes: 9 additions & 4 deletions references/detection/evaluate_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
for images, targets in tqdm(val_loader):
images = batch_transforms(images)
targets = [t['boxes'] for t in targets]
out = model(images, targets, training=False, return_boxes=True)
out = model(images, targets, training=False, return_preds=True)
# Compute metric
loc_preds = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
Expand Down Expand Up @@ -82,14 +82,19 @@ def main(args):
ds = datasets.__dict__[args.dataset](
train=True,
download=True,
rotated_bbox=args.rotation,
use_polygons=args.rotation,
sample_transforms=T.Resize(input_shape[:2]),
)
# Monkeypatch
subfolder = ds.root.split("/")[-2:]
ds.root = str(Path(ds.root).parent.parent)
ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data]
_ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation)
_ds = datasets.__dict__[args.dataset](
train=False,
download=True,
use_polygons=args.rotation,
sample_transforms=T.Resize(input_shape[:2]),
)
subfolder = _ds.root.split("/")[-2:]
ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data])

Expand All @@ -106,7 +111,7 @@ def main(args):
batch_transforms = T.Normalize(mean=mean, std=std)

# Metrics
metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape[:2])
metric = LocalizationConfusion(use_polygons=args.rotation, mask_shape=input_shape[:2])

print("Running evaluation")
val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric)
Expand Down
163 changes: 163 additions & 0 deletions references/recognition/evaluate_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os

os.environ['USE_TORCH'] = '1'

import multiprocessing as mp
import time

import torch
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.transforms import Normalize
from tqdm import tqdm

from doctr import datasets
from doctr import transforms as T
from doctr.datasets import VOCABS
from doctr.models import recognition
from doctr.utils.metrics import TextMatch


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
# Reset val metric
val_metric.reset()
# Validation loop
val_loss, batch_cnt = 0, 0
for images, targets in tqdm(val_loader):
try:
targets = [t['labels'][0] for t in targets]
if torch.cuda.is_available():
images = images.cuda()
images = batch_transforms(images)
if amp:
with torch.cuda.amp.autocast():
out = model(images, targets, return_preds=True)
else:
out = model(images, targets, return_preds=True)
# Compute metric
if len(out['preds']):
words, _ = zip(*out['preds'])
else:
words = []
val_metric.update(targets, words)

val_loss += out['loss'].item()
batch_cnt += 1
except ValueError:
print(f"unexpected symbol/s in targets:\n{targets} \n--> skip batch")
continue

val_loss /= batch_cnt
result = val_metric.summary()
return val_loss, result['raw'], result['unicase']


def main(args):

print(args)

torch.backends.cudnn.benchmark = True

if not isinstance(args.workers, int):
args.workers = min(16, mp.cpu_count())

# Load doctr model
model = recognition.__dict__[args.arch](
pretrained=True if args.resume is None else False,
input_shape=(3, args.input_size, 4 * args.input_size),
vocab=VOCABS[args.vocab],
).eval()

# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint)

st = time.time()
ds = datasets.__dict__[args.dataset](
train=True,
download=True,
recognition_task=True,
use_polygons=args.regular,
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)

_ds = datasets.__dict__[args.dataset](
train=False,
download=True,
recognition_task=True,
use_polygons=args.regular,
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
)
ds.data.extend([(np_img, target) for np_img, target in _ds.data])

test_loader = DataLoader(
ds,
batch_size=args.batch_size,
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(ds),
pin_memory=torch.cuda.is_available(),
collate_fn=ds.collate_fn,
)
print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in "
f"{len(test_loader)} batches)")

mean, std = model.cfg['mean'], model.cfg['std']
batch_transforms = Normalize(mean=mean, std=std)

# Metrics
val_metric = TextMatch()

# GPU
if isinstance(args.device, int):
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
if args.device >= torch.cuda.device_count():
raise ValueError("Invalid device index")
# Silent default switch to GPU if available
elif torch.cuda.is_available():
args.device = 0
else:
print("No accessible GPU, targe device set to CPU.")
if torch.cuda.is_available():
torch.cuda.set_device(args.device)
model = model.cuda()

print("Running evaluation")
val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp)
print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")


def parse_args():
import argparse
parser = argparse.ArgumentParser(description='docTR evaluation script for text recognition (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('arch', type=str, help='text-recognition model to evaluate')
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for evaluation')
parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on')
parser.add_argument('--device', default=None, type=int, help='device')
parser.add_argument('-b', '--batch_size', type=int, default=32, help='batch size for evaluation')
parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = 4*H')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--only_regular', dest='regular', action='store_true',
help='test set contains only regular text')
parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()

return args


if __name__ == "__main__":
args = parse_args()
main(args)
Loading

0 comments on commit 210ecc4

Please sign in to comment.