From 57ea2d6e6121e1e45ec3f3e4f04e750267f6e71c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 21 Dec 2021 08:50:32 +0100
Subject: [PATCH 1/4] add prototype dataset

---
 .../prototype/datasets/_builtin/__init__.py |  1 +
 .../prototype/datasets/_builtin/fer2013.py  | 80 +++++++++++++++++++
 .../prototype/datasets/utils/__init__.py    |  2 +-
 .../prototype/datasets/utils/_resource.py   | 14 ++++
 4 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 torchvision/prototype/datasets/_builtin/fer2013.py

diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 62abc3119f6..9971ab717b2 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -2,6 +2,7 @@
 from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
 from .coco import Coco
+from .fer2013 import FER2013
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py
new file mode 100644
index 00000000000..2d9bd713990
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/fer2013.py
@@ -0,0 +1,80 @@
+import functools
+import io
+from typing import Any, Callable, Dict, List, Optional, Union, cast
+
+import torch
+from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
+from torchvision.prototype.datasets.decoder import raw
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    OnlineResource,
+    DatasetType,
+    KaggleDownloadResource,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    hint_sharding,
+    hint_shuffling,
+    image_buffer_from_array,
+)
+from torchvision.prototype.features import Label, Image
+
+
+class FER2013(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "fer2013",
+            type=DatasetType.RAW,
+            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
+            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
+            valid_options=dict(split=("train", "test")),
+        )
+
+    _CHECKSUMS = {
+        "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
+        "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
+    }
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        archive = KaggleDownloadResource(
+            cast(str, self.info.homepage),
+            file_name=f"{config.split}.csv.zip",
+            sha256=self._CHECKSUMS[config.split],
+        )
+        return [archive]
+
+    def _collate_and_decode_sample(
+        self,
+        data: Dict[str, Any],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
+        label_id = data.get("emotion")
+        label_idx = int(label_id) if label_id is not None else None
+
+        image: Union[Image, io.BytesIO]
+        if decoder is raw:
+            image = Image(raw_image)
+        else:
+            image_buffer = image_buffer_from_array(raw_image.numpy())
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
+
+        return dict(
+            image=image,
+            label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None,
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        dp = resource_dps[0]
+        dp = CSVDictParser(dp)
+        dp = hint_sharding(dp)
+        dp = hint_shuffling(dp)
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index 92bcffc0cdb..bde05c49cb1 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,4 +1,4 @@
 from . import _internal
 from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
 from ._query import SampleQuery
-from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
+from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 94603bfc81e..20f794075b6 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -176,3 +176,17 @@ def _download(self, root: pathlib.Path) -> NoReturn:
             f"Please follow the instructions below and place it in {root}\n\n"
             f"{self.instructions}"
         )
+
+
+class KaggleDownloadResource(ManualDownloadResource):
+    def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
+        instructions = "\n".join(
+            (
+                "1. Register and login at https://www.kaggle.com",
+                f"2. Navigate to {challenge_url}",
+                "3. Click 'Join Competition' and follow the instructions there",
+                "4. Navigate to the 'Data' tab",
+                f"5. Select {file_name} in the 'Data Explorer' and click the download button",
+            )
+        )
+        super().__init__(instructions, file_name=file_name, **kwargs)

From e647f45ce39a50afeab65fd641e4305822319cfb Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 21 Dec 2021 18:43:16 +0100
Subject: [PATCH 2/4] add old style dataset

---
 docs/source/datasets.rst         |  1 +
 test/test_datasets.py            | 34 ++++++++++++++
 torchvision/datasets/__init__.py |  2 +
 torchvision/datasets/fer2013.py  | 78 ++++++++++++++++++++++++++++++++
 4 files changed, 115 insertions(+)
 create mode 100644 torchvision/datasets/fer2013.py

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 7f09ff245ca..753d019dffb 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes `_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+    """
+
+    _RESOURCES = {
+        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
+        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        with open(self._verify_integrity(), "r", newline="") as file:
+            self._samples = [
+                (
+                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
+                    int(row["emotion"]) if "emotion" in row else None,
+                )
+                for row in csv.DictReader(file)
+            ]
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_tensor, target = self._samples[idx]
+        image = Image.fromarray(image_tensor.numpy())
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def _verify_integrity(self):
+        base_folder = os.path.join(self.root, type(self).__name__.lower())
+        file_name, md5 = self._RESOURCES[self._split]
+        file = os.path.join(base_folder, file_name)
+        if not check_integrity(file, md5=md5):
+            raise RuntimeError(
+                f"{file_name} not found in {base_folder} or corrupted. "
+                f"You can download it from "
+                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+            )
+        return file
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"

From 41df18eca638805b86da45e6770016d295b17169 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 5 Jan 2022 14:03:03 +0100
Subject: [PATCH 3/4] Apply suggestions from code review

Co-authored-by: Nicolas Hug

---
 torchvision/datasets/fer2013.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/datasets/fer2013.py b/torchvision/datasets/fer2013.py
index d76fe3b424b..0f54e1c4935 100644
--- a/torchvision/datasets/fer2013.py
+++ b/torchvision/datasets/fer2013.py
@@ -16,7 +16,7 @@ class FER2013(VisionDataset):
 
     Args:
         root (string): Root directory of dataset where directory
-            ``caltech101`` exists or will be saved to if download is set to True.
+            ``root/fer2013`` exists.
         split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
         transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
             version. E.g, ``transforms.RandomCrop``
@@ -63,7 +63,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
         return image, target
 
     def _verify_integrity(self):
-        base_folder = os.path.join(self.root, type(self).__name__.lower())
+        base_folder = os.path.join(self.root, "fer2013")
         file_name, md5 = self._RESOURCES[self._split]
         file = os.path.join(base_folder, file_name)
         if not check_integrity(file, md5=md5):

From 68982c6ae14a06f939b788954faf0c1999bccccc Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 5 Jan 2022 14:27:37 +0100
Subject: [PATCH 4/4] refactor integrity check

---
 torchvision/datasets/fer2013.py | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)

diff --git a/torchvision/datasets/fer2013.py b/torchvision/datasets/fer2013.py
index 0f54e1c4935..60cbfd9bf28 100644
--- a/torchvision/datasets/fer2013.py
+++ b/torchvision/datasets/fer2013.py
@@ -1,6 +1,5 @@
 import csv
-import os
-import os.path
+import pathlib
 from typing import Any, Callable, Optional, Tuple
 
 import torch
@@ -38,7 +37,17 @@ def __init__(
         self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
         super().__init__(root, transform=transform, target_transform=target_transform)
 
-        with open(self._verify_integrity(), "r", newline="") as file:
+        base_folder = pathlib.Path(self.root) / "fer2013"
+        file_name, md5 = self._RESOURCES[self._split]
+        data_file = base_folder / file_name
+        if not check_integrity(str(data_file), md5=md5):
+            raise RuntimeError(
+                f"{file_name} not found in {base_folder} or corrupted. "
+                f"You can download it from "
+                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+            )
+
+        with open(data_file, "r", newline="") as file:
             self._samples = [
                 (
                     torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
@@ -62,17 +71,5 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
 
         return image, target
 
-    def _verify_integrity(self):
-        base_folder = os.path.join(self.root, "fer2013")
-        file_name, md5 = self._RESOURCES[self._split]
-        file = os.path.join(base_folder, file_name)
-        if not check_integrity(file, md5=md5):
-            raise RuntimeError(
-                f"{file_name} not found in {base_folder} or corrupted. "
-                f"You can download it from "
-                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
-            )
-        return file
-
     def extra_repr(self) -> str:
         return f"split={self._split}"
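
A minimal usage sketch for the old-style dataset added in this series (not taken from the patches; the root path and the transform choice are illustrative). It assumes the FER2013 export that [PATCH 2/4] adds to torchvision/datasets/__init__.py, and that train.csv has already been downloaded manually from the Kaggle challenge page into <root>/fer2013/, since the class only verifies the file and never downloads it:

    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # Expects data/fer2013/train.csv to exist; after [PATCH 4/4] the constructor
    # raises a RuntimeError with download instructions if the file is missing or corrupted.
    dataset = datasets.FER2013(root="data", split="train", transform=transforms.ToTensor())

    image, target = dataset[0]
    # image: 1x48x48 float tensor (a 48x48 grayscale PIL image before ToTensor)
    # target: emotion index in [0, 6] for the train split, None for the test split
    print(image.shape, target)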