Skip to content

Commit

Permalink
SVHN dataset for torchvision (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
uridah authored and soumith committed Mar 16, 2017
1 parent c4f4c73 commit 23e0d65
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ STL10
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset already downloaded, does not do
anything.

SVHN
~~~~~

``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)``

- ``root`` : root directory of dataset where there is folder ``SVHN``
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, does not do
anything.

ImageFolder
~~~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST
from .svhn import SVHN

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'MNIST', 'STL10')
'MNIST', 'STL10', 'SVHN')
111 changes: 111 additions & 0 deletions torchvision/datasets/svhn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys


class SVHN(data.Dataset):
url = ""
filename = ""
file_md5 = ""

split_list = {
'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
"train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
"test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.split = split # training set or test set or extra set

if self.split not in self.split_list:
raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')

self.url = self.split_list[split][0]
self.filename = self.split_list[split][1]
self.file_md5 = self.split_list[split][2]

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

# import here rather than at top of file because this is
# an optional dependency for torchvision
import scipy.io as sio

# reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(root, self.filename))

self.data = loaded_mat['X']
self.labels = loaded_mat['y']
self.data = np.transpose(self.data, (3, 2, 0, 1))

def __getitem__(self, index):
img, target = self.data[index], self.labels[index]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1, 2, 0)))

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.data)

def _check_integrity(self):
import hashlib
root = self.root
md5 = self.split_list[self.split][2]
fpath = os.path.join(root, self.filename)
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True

def download(self):
from six.moves import urllib
import tarfile
import hashlib

root = self.root
fpath = os.path.join(root, self.filename)

try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise

if self._check_integrity():
print('Files already downloaded and verified')
return

# downloads file
if os.path.isfile(fpath):
print('Using downloaded file: ' + fpath)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)
print ('Downloaded!')

0 comments on commit 23e0d65

Please sign in to comment.