-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add DTD as prototype dataset * add old style dataset * add test for old dataset * fix tests for windows * add dataset to docs * remove properties and use pathlib * Apply suggestions from code review Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> * fold -> partition Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
- Loading branch information
1 parent
df628c4
commit 5c9c835
Showing
7 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
import pathlib | ||
from typing import Optional, Callable | ||
|
||
import PIL.Image | ||
|
||
from .utils import verify_str_arg, download_and_extract_archive | ||
from .vision import VisionDataset | ||
|
||
|
||
class DTD(VisionDataset): | ||
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_. | ||
Args: | ||
root (string): Root directory of the dataset. | ||
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. | ||
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``. | ||
.. note:: | ||
The partition only changes which split each image belongs to. Thus, regardless of the selected | ||
partition, combining all splits will result in all images. | ||
download (bool, optional): If True, downloads the dataset from the internet and | ||
puts it in root directory. If dataset is already downloaded, it is not | ||
downloaded again. | ||
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed | ||
version. E.g, ``transforms.RandomCrop``. | ||
target_transform (callable, optional): A function/transform that takes in the target and transforms it. | ||
""" | ||
|
||
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" | ||
_MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1" | ||
|
||
def __init__( | ||
self, | ||
root: str, | ||
split: str = "train", | ||
partition: int = 1, | ||
download: bool = True, | ||
transform: Optional[Callable] = None, | ||
target_transform: Optional[Callable] = None, | ||
) -> None: | ||
self._split = verify_str_arg(split, "split", ("train", "val", "test")) | ||
if not isinstance(partition, int) and not (1 <= partition <= 10): | ||
raise ValueError( | ||
f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, " | ||
f"but got {partition} instead" | ||
) | ||
self._partition = partition | ||
|
||
super().__init__(root, transform=transform, target_transform=target_transform) | ||
self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower() | ||
self._data_folder = self._base_folder / "dtd" | ||
self._meta_folder = self._data_folder / "labels" | ||
self._images_folder = self._data_folder / "images" | ||
|
||
if download: | ||
self._download() | ||
|
||
if not self._check_exists(): | ||
raise RuntimeError("Dataset not found. You can use download=True to download it") | ||
|
||
self._image_files = [] | ||
classes = [] | ||
with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file: | ||
for line in file: | ||
cls, name = line.strip().split("/") | ||
self._image_files.append(self._images_folder.joinpath(cls, name)) | ||
classes.append(cls) | ||
|
||
self.classes = sorted(set(classes)) | ||
self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) | ||
self._labels = [self.class_to_idx[cls] for cls in classes] | ||
|
||
def __len__(self) -> int: | ||
return len(self._image_files) | ||
|
||
def __getitem__(self, idx): | ||
image_file, label = self._image_files[idx], self._labels[idx] | ||
image = PIL.Image.open(image_file).convert("RGB") | ||
|
||
if self.transform: | ||
image = self.transform(image) | ||
|
||
if self.target_transform: | ||
label = self.target_transform(label) | ||
|
||
return image, label | ||
|
||
def extra_repr(self) -> str: | ||
return f"split={self._split}, partition={self._partition}" | ||
|
||
def _check_exists(self) -> bool: | ||
return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder) | ||
|
||
def _download(self) -> None: | ||
if self._check_exists(): | ||
return | ||
download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
banded | ||
blotchy | ||
braided | ||
bubbly | ||
bumpy | ||
chequered | ||
cobwebbed | ||
cracked | ||
crosshatched | ||
crystalline | ||
dotted | ||
fibrous | ||
flecked | ||
freckled | ||
frilly | ||
gauzy | ||
grid | ||
grooved | ||
honeycombed | ||
interlaced | ||
knitted | ||
lacelike | ||
lined | ||
marbled | ||
matted | ||
meshed | ||
paisley | ||
perforated | ||
pitted | ||
pleated | ||
polka-dotted | ||
porous | ||
potholed | ||
scaly | ||
smeared | ||
spiralled | ||
sprinkled | ||
stained | ||
stratified | ||
striped | ||
studded | ||
swirly | ||
veined | ||
waffled | ||
woven | ||
wrinkled | ||
zigzagged |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import io | ||
import pathlib | ||
from typing import Any, Callable, Dict, List, Optional, Tuple | ||
|
||
import torch | ||
from torchdata.datapipes.iter import ( | ||
IterDataPipe, | ||
Mapper, | ||
Shuffler, | ||
Filter, | ||
IterKeyZipper, | ||
Demultiplexer, | ||
LineReader, | ||
CSVParser, | ||
) | ||
from torchvision.prototype.datasets.utils import ( | ||
Dataset, | ||
DatasetConfig, | ||
DatasetInfo, | ||
HttpResource, | ||
OnlineResource, | ||
DatasetType, | ||
) | ||
from torchvision.prototype.datasets.utils._internal import ( | ||
INFINITE_BUFFER_SIZE, | ||
hint_sharding, | ||
path_comparator, | ||
getitem, | ||
) | ||
from torchvision.prototype.features import Label | ||
|
||
|
||
class DTD(Dataset): | ||
def _make_info(self) -> DatasetInfo: | ||
return DatasetInfo( | ||
"dtd", | ||
type=DatasetType.IMAGE, | ||
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", | ||
valid_options=dict( | ||
split=("train", "test", "val"), | ||
fold=tuple(str(fold) for fold in range(1, 11)), | ||
), | ||
) | ||
|
||
def resources(self, config: DatasetConfig) -> List[OnlineResource]: | ||
archive = HttpResource( | ||
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", | ||
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", | ||
decompress=True, | ||
) | ||
return [archive] | ||
|
||
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: | ||
path = pathlib.Path(data[0]) | ||
if path.parent.name == "labels": | ||
if path.name == "labels_joint_anno.txt": | ||
return 1 | ||
|
||
return 0 | ||
elif path.parents[1].name == "images": | ||
return 2 | ||
else: | ||
return None | ||
|
||
def _image_key_fn(self, data: Tuple[str, Any]) -> str: | ||
path = pathlib.Path(data[0]) | ||
return str(path.relative_to(path.parents[1])) | ||
|
||
def _collate_and_decode_sample( | ||
self, | ||
data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]], | ||
*, | ||
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], | ||
) -> Dict[str, Any]: | ||
(_, joint_categories_data), image_data = data | ||
_, *joint_categories = joint_categories_data | ||
path, buffer = image_data | ||
|
||
category = pathlib.Path(path).parent.name | ||
|
||
return dict( | ||
joint_categories={category for category in joint_categories if category}, | ||
label=Label(self.info.categories.index(category), category=category), | ||
path=path, | ||
image=decoder(buffer) if decoder else buffer, | ||
) | ||
|
||
def _make_datapipe( | ||
self, | ||
resource_dps: List[IterDataPipe], | ||
*, | ||
config: DatasetConfig, | ||
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], | ||
) -> IterDataPipe[Dict[str, Any]]: | ||
archive_dp = resource_dps[0] | ||
|
||
splits_dp, joint_categories_dp, images_dp = Demultiplexer( | ||
archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE | ||
) | ||
|
||
splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt")) | ||
splits_dp = LineReader(splits_dp, decode=True, return_path=False) | ||
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE) | ||
splits_dp = hint_sharding(splits_dp) | ||
|
||
joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ") | ||
|
||
dp = IterKeyZipper( | ||
splits_dp, | ||
joint_categories_dp, | ||
key_fn=getitem(), | ||
ref_key_fn=getitem(0), | ||
buffer_size=INFINITE_BUFFER_SIZE, | ||
) | ||
dp = IterKeyZipper( | ||
dp, | ||
images_dp, | ||
key_fn=getitem(0), | ||
ref_key_fn=self._image_key_fn, | ||
buffer_size=INFINITE_BUFFER_SIZE, | ||
) | ||
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) | ||
|
||
def _filter_images(self, data: Tuple[str, Any]) -> bool: | ||
return self._classify_archive(data) == 2 | ||
|
||
def _generate_categories(self, root: pathlib.Path) -> List[str]: | ||
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) | ||
dp = Filter(dp, self._filter_images) | ||
return sorted({pathlib.Path(path).parent.name for path, _ in dp}) |