diff --git a/README.rst b/README.rst index 9478b133064..d7067fbad76 100644 --- a/README.rst +++ b/README.rst @@ -168,6 +168,17 @@ STL10 - ``download`` : ``True`` = downloads the dataset from the internet and puts it in root directory. If dataset already downloaded, does not do anything. + +SVHN +~~~~~ + +``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)`` + +- ``root`` : root directory of dataset where there is folder ``SVHN`` +- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set +- ``download`` : ``True`` = downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, does not do + anything. ImageFolder ~~~~~~~~~~~ diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 330dc8be4e5..e7a2d3bfb30 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -4,9 +4,10 @@ from .cifar import CIFAR10, CIFAR100 from .stl10 import STL10 from .mnist import MNIST +from .svhn import SVHN __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', - 'MNIST', 'STL10') + 'MNIST', 'STL10', 'SVHN') diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py new file mode 100644 index 00000000000..65ab01aab02 --- /dev/null +++ b/torchvision/datasets/svhn.py @@ -0,0 +1,111 @@ +from __future__ import print_function +import torch.utils.data as data +from PIL import Image +import os +import os.path +import errno +import numpy as np +import sys + + +class SVHN(data.Dataset): + url = "" + filename = "" + file_md5 = "" + + split_list = { + 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", + "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], + 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", + "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], + 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", + "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} + + def __init__(self, root, split='train', transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + self.split = split # training set or test set or extra set + + if self.split not in self.split_list: + raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"') + + self.url = self.split_list[split][0] + self.filename = self.split_list[split][1] + self.file_md5 = self.split_list[split][2] + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + # import here rather than at top of file because this is + # an optional dependency for torchvision + import scipy.io as sio + + # reading(loading) mat file as array + loaded_mat = sio.loadmat(os.path.join(root, self.filename)) + + self.data = loaded_mat['X'] + self.labels = loaded_mat['y'] + self.data = np.transpose(self.data, (3, 2, 0, 1)) + + def __getitem__(self, index): + img, target = self.data[index], self.labels[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(np.transpose(img, (1, 2, 0))) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + def _check_integrity(self): + import hashlib + root = self.root + md5 = self.split_list[self.split][2] + fpath = os.path.join(root, self.filename) + if not os.path.isfile(fpath): + return False + md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest() + if md5c != md5: + return False + return True + + def download(self): + from six.moves import urllib + import tarfile + import hashlib + + root = self.root + fpath = os.path.join(root, self.filename) + + try: + os.makedirs(root) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + # downloads file + if os.path.isfile(fpath): + print('Using downloaded file: ' + fpath) + else: + print('Downloading ' + self.url + ' to ' + fpath) + urllib.request.urlretrieve(self.url, fpath) + print ('Downloaded!')