-
Notifications
You must be signed in to change notification settings - Fork 463
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[references] Add eval recognition and update eval detection scripts (#…
…933) * add recognition eval scripts * update CI job and detection scripts * update CI job * speedup reco eval script CI * rename CI job names
- Loading branch information
1 parent
97da310
commit 210ecc4
Showing
6 changed files
with
412 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Copyright (C) 2022, Mindee. | ||
|
||
# This program is licensed under the Apache License version 2. | ||
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details. | ||
|
||
import os | ||
|
||
os.environ['USE_TORCH'] = '1' | ||
|
||
import multiprocessing as mp | ||
import time | ||
|
||
import torch | ||
from torch.utils.data import DataLoader, SequentialSampler | ||
from torchvision.transforms import Normalize | ||
from tqdm import tqdm | ||
|
||
from doctr import datasets | ||
from doctr import transforms as T | ||
from doctr.datasets import VOCABS | ||
from doctr.models import recognition | ||
from doctr.utils.metrics import TextMatch | ||
|
||
|
||
@torch.no_grad() | ||
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): | ||
# Model in eval mode | ||
model.eval() | ||
# Reset val metric | ||
val_metric.reset() | ||
# Validation loop | ||
val_loss, batch_cnt = 0, 0 | ||
for images, targets in tqdm(val_loader): | ||
try: | ||
targets = [t['labels'][0] for t in targets] | ||
if torch.cuda.is_available(): | ||
images = images.cuda() | ||
images = batch_transforms(images) | ||
if amp: | ||
with torch.cuda.amp.autocast(): | ||
out = model(images, targets, return_preds=True) | ||
else: | ||
out = model(images, targets, return_preds=True) | ||
# Compute metric | ||
if len(out['preds']): | ||
words, _ = zip(*out['preds']) | ||
else: | ||
words = [] | ||
val_metric.update(targets, words) | ||
|
||
val_loss += out['loss'].item() | ||
batch_cnt += 1 | ||
except ValueError: | ||
print(f"unexpected symbol/s in targets:\n{targets} \n--> skip batch") | ||
continue | ||
|
||
val_loss /= batch_cnt | ||
result = val_metric.summary() | ||
return val_loss, result['raw'], result['unicase'] | ||
|
||
|
||
def main(args): | ||
|
||
print(args) | ||
|
||
torch.backends.cudnn.benchmark = True | ||
|
||
if not isinstance(args.workers, int): | ||
args.workers = min(16, mp.cpu_count()) | ||
|
||
# Load doctr model | ||
model = recognition.__dict__[args.arch]( | ||
pretrained=True if args.resume is None else False, | ||
input_shape=(3, args.input_size, 4 * args.input_size), | ||
vocab=VOCABS[args.vocab], | ||
).eval() | ||
|
||
# Resume weights | ||
if isinstance(args.resume, str): | ||
print(f"Resuming {args.resume}") | ||
checkpoint = torch.load(args.resume, map_location='cpu') | ||
model.load_state_dict(checkpoint) | ||
|
||
st = time.time() | ||
ds = datasets.__dict__[args.dataset]( | ||
train=True, | ||
download=True, | ||
recognition_task=True, | ||
use_polygons=args.regular, | ||
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), | ||
) | ||
|
||
_ds = datasets.__dict__[args.dataset]( | ||
train=False, | ||
download=True, | ||
recognition_task=True, | ||
use_polygons=args.regular, | ||
img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), | ||
) | ||
ds.data.extend([(np_img, target) for np_img, target in _ds.data]) | ||
|
||
test_loader = DataLoader( | ||
ds, | ||
batch_size=args.batch_size, | ||
drop_last=False, | ||
num_workers=args.workers, | ||
sampler=SequentialSampler(ds), | ||
pin_memory=torch.cuda.is_available(), | ||
collate_fn=ds.collate_fn, | ||
) | ||
print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in " | ||
f"{len(test_loader)} batches)") | ||
|
||
mean, std = model.cfg['mean'], model.cfg['std'] | ||
batch_transforms = Normalize(mean=mean, std=std) | ||
|
||
# Metrics | ||
val_metric = TextMatch() | ||
|
||
# GPU | ||
if isinstance(args.device, int): | ||
if not torch.cuda.is_available(): | ||
raise AssertionError("PyTorch cannot access your GPU. Please investigate!") | ||
if args.device >= torch.cuda.device_count(): | ||
raise ValueError("Invalid device index") | ||
# Silent default switch to GPU if available | ||
elif torch.cuda.is_available(): | ||
args.device = 0 | ||
else: | ||
print("No accessible GPU, targe device set to CPU.") | ||
if torch.cuda.is_available(): | ||
torch.cuda.set_device(args.device) | ||
model = model.cuda() | ||
|
||
print("Running evaluation") | ||
val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp) | ||
print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") | ||
|
||
|
||
def parse_args(): | ||
import argparse | ||
parser = argparse.ArgumentParser(description='docTR evaluation script for text recognition (PyTorch)', | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
|
||
parser.add_argument('arch', type=str, help='text-recognition model to evaluate') | ||
parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for evaluation') | ||
parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on') | ||
parser.add_argument('--device', default=None, type=int, help='device') | ||
parser.add_argument('-b', '--batch_size', type=int, default=32, help='batch size for evaluation') | ||
parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = 4*H') | ||
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') | ||
parser.add_argument('--only_regular', dest='regular', action='store_true', | ||
help='test set contains only regular text') | ||
parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume') | ||
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") | ||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
main(args) |
Oops, something went wrong.