Add DTD dataset (#5115)
* add DTD as prototype dataset

* add old style dataset

* add test for old dataset

* fix tests for windows

* add dataset to docs

* remove properties and use pathlib

* Apply suggestions from code review

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* fold -> partition

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
pmeier and NicolasHug authored Jan 5, 2022
1 parent df628c4 commit 5c9c835
Showing 7 changed files with 317 additions and 0 deletions.
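For context, a minimal usage sketch of the stable API added in this commit (the root path is a placeholder, not part of the diff; note that download defaults to True in this initial version):

    from torchvision import datasets

    # downloads and extracts dtd-r1.0.1.tar.gz into <root>/dtd on first use
    dtd = datasets.DTD(root="data", split="train", partition=1)
    image, label = dtd[0]  # PIL image and integer class index
    print(dtd.classes[label])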
1 change: 1 addition & 0 deletions docs/source/datasets.rst
@@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
    Cityscapes
    CocoCaptions
    CocoDetection
    DTD
    EMNIST
    FakeData
    FashionMNIST
36 changes: 36 additions & 0 deletions test/test_datasets.py
@@ -2205,5 +2205,41 @@ def inject_fake_data(self, tmpdir: str, config):
        return len(sampled_classes * n_samples_per_class)


class DTDTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.DTD
    FEATURE_TYPES = (PIL.Image.Image, int)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test", "val"),
        # There is no need to test the whole matrix here, since each partition is treated exactly the same
        partition=(1, 5, 10),
    )

    def inject_fake_data(self, tmpdir: str, config):
        data_folder = pathlib.Path(tmpdir) / "dtd" / "dtd"

        num_images_per_class = 3
        image_folder = data_folder / "images"
        image_files = []
        for cls in ("banded", "marbled", "zigzagged"):
            image_files.extend(
                datasets_utils.create_image_folder(
                    image_folder,
                    cls,
                    file_name_fn=lambda idx: f"{cls}_{idx:04d}.jpg",
                    num_examples=num_images_per_class,
                )
            )

        meta_folder = data_folder / "labels"
        meta_folder.mkdir()
        image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files]
        image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2)
        with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file:
            file.write("\n".join(image_ids_in_config) + "\n")

        return len(image_ids_in_config)


if __name__ == "__main__":
    unittest.main()
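For orientation, the fake data injected above mirrors the layout the real archive extracts to, e.g. for split="train" and partition=1:

    <tmpdir>/dtd/dtd/
        images/
            banded/banded_0000.jpg ... banded_0002.jpg
            marbled/ ...
            zigzagged/ ...
        labels/
            train1.txt  # one "<class>/<name>.jpg" id per line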
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
@@ -4,6 +4,7 @@
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .coco import CocoCaptions, CocoDetection
from .dtd import DTD
from .fakedata import FakeData
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
@@ -79,4 +80,5 @@
    "FlyingThings3D",
    "HD1K",
    "Food101",
    "DTD",
)
100 changes: 100 additions & 0 deletions torchvision/datasets/dtd.py
@@ -0,0 +1,100 @@
import os
import pathlib
from typing import Optional, Callable

import PIL.Image

from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset


class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: str,
        split: str = "train",
        partition: int = 1,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # `or`, not `and`: both a non-int and an out-of-range value must be rejected
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        self.classes = sorted(set(classes))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx):
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
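For reference, each line in a split file such as labels/train1.txt is a "<class>/<file>" image id (the concrete names below are illustrative):

    banded/banded_0002.jpg
    zigzagged/zigzagged_0017.jpg

__init__ splits each id on "/" to recover the class, so self.classes becomes the sorted set of class names and self._labels the matching integer indices.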
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -2,6 +2,7 @@
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .dtd import DTD
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
47 changes: 47 additions & 0 deletions torchvision/prototype/datasets/_builtin/dtd.categories
@@ -0,0 +1,47 @@
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
130 changes: 130 additions & 0 deletions torchvision/prototype/datasets/_builtin/dtd.py
@@ -0,0 +1,130 @@
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchdata.datapipes.iter import (
    IterDataPipe,
    Mapper,
    Shuffler,
    Filter,
    IterKeyZipper,
    Demultiplexer,
    LineReader,
    CSVParser,
)
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
    path_comparator,
    getitem,
)
from torchvision.prototype.features import Label


class DTD(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "dtd",
            type=DatasetType.IMAGE,
            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
            valid_options=dict(
                split=("train", "test", "val"),
                fold=tuple(str(fold) for fold in range(1, 11)),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
            decompress=True,
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parent.name == "labels":
            if path.name == "labels_joint_anno.txt":
                return 1

            return 0
        elif path.parents[1].name == "images":
            return 2
        else:
            return None

    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        return str(path.relative_to(path.parents[1]))

    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        (_, joint_categories_data), image_data = data
        _, *joint_categories = joint_categories_data
        path, buffer = image_data

        category = pathlib.Path(path).parent.name

        return dict(
            joint_categories={category for category in joint_categories if category},
            label=Label(self.info.categories.index(category), category=category),
            path=path,
            image=decoder(buffer) if decoder else buffer,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]

        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
        splits_dp = hint_sharding(splits_dp)

        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

        dp = IterKeyZipper(
            splits_dp,
            joint_categories_dp,
            key_fn=getitem(),
            ref_key_fn=getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=self._image_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        return self._classify_archive(data) == 2

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
        dp = Filter(dp, self._filter_images)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
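A sketch of how the prototype dataset could be consumed, assuming the prototype datasets.load() entry point of that era (the entry point itself is not part of this diff, so treat the call below as illustrative):

    from torchvision.prototype import datasets

    # each sample is a dict with "image", "label", "joint_categories", and "path" keys
    for sample in datasets.load("dtd", split="train", fold="1"):
        print(sample["path"], sample["label"])
        break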
