feat: added DocArtefacts dataset (#583)
* initial commit

* style: Fixed indentation

* Refactor: Removed unused import

* docs: Added entry to documentation

* feat: Enabled FP16

* fix: Fixed the rotated bbox conversion

* fix: Fixed orientation in bbox conversion

* style: Reordered imports

* test: Added corresponding unittests

* test: Fixed coverage

Co-authored-by: fg-mindee <fg@mindee.co>
SiddhantBahuguna and fg-mindee authored Nov 5, 2021
1 parent e96acd1 commit f537551
Showing 5 changed files with 84 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
@@ -18,6 +18,7 @@ Here are all datasets that are available through docTR:
.. autoclass:: CORD
.. autoclass:: OCRDataset
.. autoclass:: CharacterGenerator
+.. autoclass:: DocArtefacts


Data Loading
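For context, the class documented by this new entry is used like the other docTR datasets. Below is a minimal usage sketch, assuming the import paths suggested elsewhere in this commit (doctr.datasets for the dataset; doctr.transforms for the Resize transform used in the updated tests further down):

from doctr.datasets import DocArtefacts
from doctr.transforms import Resize  # import path assumed; Resize is the transform used in the tests below

# Download the archive and build the dataset; each sample is an image plus a
# dict with "boxes" and "labels", as constructed in doc_artefacts.py below.
ds = DocArtefacts(download=True, sample_transforms=Resize((512, 512)), rotated_bbox=False)
img, target = ds[0]
print(len(ds), target["boxes"].shape, target["labels"][:3])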
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
@@ -3,6 +3,7 @@
from .classification import *
from .cord import *
from .detection import *
+from .doc_artefacts import *
from .funsd import *
from .ocr import *
from .recognition import *
70 changes: 70 additions & 0 deletions doctr/datasets/doc_artefacts.py
@@ -0,0 +1,70 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

from .datasets import VisionDataset

__all__ = ['DocArtefacts']


class DocArtefacts(VisionDataset):
    """Object detection dataset of document artefacts.

    Example::
        >>> from doctr.datasets import DocArtefacts
        >>> train_set = DocArtefacts(download=True)
        >>> img, target = train_set[0]

    Args:
        sample_transforms: composable transformations that will be applied to each image
        rotated_bbox: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = 'https://github.com/mindee/doctr/releases/download/v0.4.0/artefact_detection-6c401d4d.zip'
    SHA256 = '6c401d4d5d4ebaf086c3ed81a7d8142f48161420ab693bf8ac384e413a9c7d19'
    FILE_NAME = 'artefact_detection-6c401d4d.zip'

    def __init__(
        self,
        sample_transforms: Optional[Callable[[Any], Any]] = None,
        rotated_bbox: bool = False,
        **kwargs: Any,
    ) -> None:

        super().__init__(self.URL, self.FILE_NAME, self.SHA256, True, **kwargs)
        self.sample_transforms = sample_transforms

        # List images
        tmp_root = os.path.join(self.root, 'images')
        with open(os.path.join(self.root, "labels.json"), "rb") as f:
            labels = json.load(f)
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        img_list = os.listdir(tmp_root)
        if len(labels) != len(img_list):
            raise AssertionError('the number of images and labels do not match')
        np_dtype = np.float16 if self.fp16 else np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")
            boxes = np.asarray([obj['geometry'] for obj in label], dtype=np_dtype)
            classes = [obj['label'] for obj in label]
            if rotated_bbox:
                # box_targets: xmin, ymin, xmax, ymax -> x, y, w, h, alpha = 0
                boxes = np.stack((
                    boxes[:, [0, 2]].mean(axis=1),
                    boxes[:, [1, 3]].mean(axis=1),
                    boxes[:, 2] - boxes[:, 0],
                    boxes[:, 3] - boxes[:, 1],
                    np.zeros(boxes.shape[0], dtype=np_dtype),
                ), axis=1)
            self.data.append((img_name, dict(boxes=boxes, labels=classes)))
        self.root = tmp_root
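The rotated_bbox branch above converts straight boxes into 5-parameter rotated boxes. Here is a standalone sketch of that conversion with made-up coordinates (not part of the commit), just to make the mapping explicit:

import numpy as np

# Two hypothetical straight boxes in (xmin, ymin, xmax, ymax) format
boxes = np.asarray([[10, 20, 50, 80], [0, 0, 30, 40]], dtype=np.float32)

# Same mapping as in DocArtefacts.__init__:
# (xmin, ymin, xmax, ymax) -> (x_center, y_center, width, height, alpha=0)
rboxes = np.stack((
    boxes[:, [0, 2]].mean(axis=1),                 # x center
    boxes[:, [1, 3]].mean(axis=1),                 # y center
    boxes[:, 2] - boxes[:, 0],                     # width
    boxes[:, 3] - boxes[:, 1],                     # height
    np.zeros(boxes.shape[0], dtype=np.float32),    # angle: 0 for axis-aligned boxes
), axis=1)

print(rboxes)
# [[30. 50. 40. 60.  0.]
#  [15. 20. 30. 40.  0.]]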
9 changes: 6 additions & 3 deletions test/pytorch/test_datasets_pt.py
@@ -28,16 +28,19 @@ def test_visiondataset():
        ['SROIE', False, [512, 512], 360, False],
        ['CORD', True, [512, 512], 800, True],
        ['CORD', False, [512, 512], 100, False],
+        ['DocArtefacts', None, [512, 512], 3000, False],
+        ['DocArtefacts', None, [512, 512], 3000, True],
    ],
)
def test_dataset(dataset_name, train, input_size, size, rotate):

+    kwargs = {} if train is None else {"train": train}
    ds = datasets.__dict__[dataset_name](
-        train=train, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate
+        download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate, **kwargs,
    )

    assert len(ds) == size
-    assert repr(ds) == f"{dataset_name}(train={train})"
+    assert repr(ds) == (f"{dataset_name}()" if train is None else f"{dataset_name}(train={train})")
    img, target = ds[0]
    assert isinstance(img, torch.Tensor)
    assert img.shape == (3, *input_size)
@@ -53,7 +56,7 @@ def test_dataset(dataset_name, train, input_size, size, rotate):
    assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)

    # FP16 checks
-    ds = datasets.__dict__[dataset_name](train=train, download=True, fp16=True)
+    ds = datasets.__dict__[dataset_name](download=True, fp16=True, **kwargs)
    img, target = ds[0]
    assert img.dtype == torch.float16

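The new kwargs indirection in this test exists because DocArtefacts is parametrized with train=None: unlike FUNSD, SROIE, or CORD, it exposes no train split argument, so train is only forwarded when present. A minimal sketch of the same pattern outside pytest (build_dataset is a hypothetical helper, not part of the commit):

from doctr import datasets

def build_dataset(dataset_name, train=None, **extra):
    # Forward `train` only when the caller provides it, so datasets without
    # a train/test split (such as DocArtefacts) can share the same code path.
    kwargs = {} if train is None else {"train": train}
    return datasets.__dict__[dataset_name](download=True, **kwargs, **extra)

ds = build_dataset("DocArtefacts", rotated_bbox=False)
print(len(ds), repr(ds))  # the test above expects 3000 and "DocArtefacts()"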
9 changes: 6 additions & 3 deletions test/tensorflow/test_datasets_tf.py
@@ -18,16 +18,19 @@
        ['SROIE', False, [512, 512], 360, False],
        ['CORD', True, [512, 512], 800, True],
        ['CORD', False, [512, 512], 100, False],
+        ['DocArtefacts', None, [512, 512], 3000, False],
+        ['DocArtefacts', None, [512, 512], 3000, True],
    ],
)
def test_dataset(dataset_name, train, input_size, size, rotate):

+    kwargs = {} if train is None else {"train": train}
    ds = datasets.__dict__[dataset_name](
-        train=train, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate
+        download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate, **kwargs,
    )

    assert len(ds) == size
-    assert repr(ds) == f"{dataset_name}(train={train})"
+    assert repr(ds) == (f"{dataset_name}()" if train is None else f"{dataset_name}(train={train})")
    img, target = ds[0]
    assert isinstance(img, tf.Tensor)
    assert img.shape == (*input_size, 3)
@@ -40,7 +43,7 @@ def test_dataset(dataset_name, train, input_size, size, rotate):
    assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)

    # FP16
-    ds = datasets.__dict__[dataset_name](train=train, download=True, fp16=True)
+    ds = datasets.__dict__[dataset_name](download=True, fp16=True, **kwargs)
    img, target = ds[0]
    assert img.dtype == tf.float16

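Both test files end by exercising the new FP16 path (cf. the "feat: Enabled FP16" commit message): with fp16=True the images come back as float16 tensors, and the numpy box targets follow the np_dtype switch in doc_artefacts.py. A small illustrative check of that switch, framework-agnostic and using numpy only (not part of the commit):

import numpy as np

def target_dtype(fp16: bool):
    # Same switch as in DocArtefacts.__init__
    return np.float16 if fp16 else np.float32

assert np.asarray([[0., 0., 1., 1.]], dtype=target_dtype(True)).dtype == np.float16
assert np.asarray([[0., 0., 1., 1.]], dtype=target_dtype(False)).dtype == np.float32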
