added FashionMNIST dataset (#238)
* added FashionMNIST dataset

* documentation

* fixed formatting

* fixed formatting
kashif authored and soumith committed Sep 2, 2017
1 parent 7492fae commit eec5ba4
Showing 3 changed files with 24 additions and 11 deletions.
20 changes: 11 additions & 9 deletions README.rst
@@ -43,7 +43,7 @@ Datasets

 The following dataset loaders are available:

-- `MNIST <#mnist>`__
+- `MNIST and FashionMNIST <#mnist>`__
 - `COCO (Captioning and Detection) <#coco>`__
 - `LSUN Classification <#lsun>`__
 - `ImageFolder <#imagefolder>`__
@@ -77,6 +77,8 @@ MNIST
 ~~~~~
 ``dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)``

+``dset.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)``
+
 ``root``: root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist

 ``train``: ``True`` - use training set, ``False`` - use test set.
@@ -390,32 +392,32 @@ For example:
 Utils
 =====

-make\_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0)
+``make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Given a 4D mini-batch Tensor of shape (B x C x H x W),
 or a list of images all of the same size,
 makes a grid of images

-normalize=True will shift the image to the range (0, 1),
+``normalize=True`` will shift the image to the range (0, 1),
 by subtracting the minimum and dividing by the maximum pixel value.

-if range=(min, max) where min and max are numbers, then these numbers are used to
+if ``range=(min, max)`` where ``min`` and ``max`` are numbers, then these numbers are used to
 normalize the image.

-scale_each=True will scale each image in the batch of images separately rather than
-computing the (min, max) over all images.
+``scale_each=True`` will scale each image in the batch of images separately rather than
+computing the ``(min, max)`` over all images.

-pad_value=<float> sets the value for the padded pixels.
+``pad_value=<float>`` sets the value for the padded pixels.

 `Example usage is given in this notebook` <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>

-save\_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0)
+``save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Saves a given Tensor into an image file.

 If given a mini-batch tensor, will save the tensor as a grid of images.

-All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
+All options after ``filename`` are passed through to ``make_grid``. Refer to it's documentation for
 more details
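
The `normalize`, `range` and `scale_each` options documented in the README hunk above describe a min-max scaling of pixel values. The following is a minimal sketch of that behaviour on plain Python lists, not torchvision's own code. It assumes the conventional formula `(x - min) / (max - min)` with clamping to the given range, which is one reading of the README's "subtracting the minimum and dividing by the maximum". The names `min_max_scale`, `normalize_batch` and `value_range` are hypothetical and chosen only for illustration.

```python
def min_max_scale(pixels, lo, hi):
    """Clamp each value to [lo, hi], then map that interval onto [0, 1]."""
    span = hi - lo
    if span == 0:
        # Degenerate range: every value collapses to 0.0 (an assumption of this sketch).
        return [0.0 for _ in pixels]
    return [(min(max(p, lo), hi) - lo) / span for p in pixels]


def normalize_batch(images, value_range=None, scale_each=False):
    """Scale a batch of images, each given as a flat list of pixel values.

    value_range plays the role of the README's ``range=(min, max)``;
    scale_each=True computes (min, max) per image instead of over the batch.
    """
    if scale_each:
        scaled = []
        for img in images:
            lo, hi = value_range if value_range is not None else (min(img), max(img))
            scaled.append(min_max_scale(img, lo, hi))
        return scaled
    if value_range is not None:
        lo, hi = value_range
    else:
        flat = [p for img in images for p in img]
        lo, hi = min(flat), max(flat)
    return [min_max_scale(img, lo, hi) for img in images]


batch = [[0.0, 2.0, 4.0], [4.0, 6.0, 8.0]]
shared = normalize_batch(batch)                          # one (min, max) for the whole batch
separate = normalize_batch(batch, scale_each=True)       # one (min, max) per image
fixed = normalize_batch(batch, value_range=(0.0, 4.0))   # caller-supplied range
```

With a shared range the second image stays brighter than the first; with `scale_each=True` both images span the full (0, 1) interval, which is the difference the README text describes.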
4 changes: 2 additions & 2 deletions torchvision/datasets/__init__.py
@@ -3,13 +3,13 @@
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .stl10 import STL10
-from .mnist import MNIST
+from .mnist import MNIST, FashionMNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
 from .fakedata import FakeData

 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100',
+           'CIFAR10', 'CIFAR100', 'FashionMNIST',
            'MNIST', 'STL10', 'SVHN', 'PhotoTour')
11 changes: 11 additions & 0 deletions torchvision/datasets/mnist.py
@@ -139,6 +139,17 @@ def download(self):
         print('Done!')


+class FashionMNIST(MNIST):
+    """`Fashion MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
+    """
+    urls = [
+        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
+        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
+        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
+        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
+    ]
+
+
 def get_int(b):
     return int(codecs.encode(b, 'hex'), 16)
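
The `FashionMNIST` class added in this hunk contains no logic of its own: it subclasses `MNIST` and overrides only the `urls` class attribute, so it inherits the existing download and loading code unchanged. The pattern can be sketched without torchvision as follows. This is an illustration under stated assumptions, not the library's code: `BaseDataset`, `VariantDataset`, the `filenames` method and the `example.com` URLs are all hypothetical.

```python
class BaseDataset:
    """Hypothetical stand-in for MNIST: all behaviour is defined here."""

    urls = [
        'http://example.com/base/train-images-idx3-ubyte.gz',
        'http://example.com/base/train-labels-idx1-ubyte.gz',
    ]

    def filenames(self):
        # self.urls is looked up on the instance's class first, so a
        # subclass that redefines `urls` changes what this method sees.
        return [url.rpartition('/')[2] for url in self.urls]


class VariantDataset(BaseDataset):
    """Hypothetical stand-in for FashionMNIST: only the data source differs."""

    urls = [
        'http://example.com/variant/train-images-idx3-ubyte.gz',
        'http://example.com/variant/train-labels-idx1-ubyte.gz',
    ]


base_sources = BaseDataset().urls
variant_sources = VariantDataset().urls
variant_files = VariantDataset().filenames()
```

Because the README hunk documents the same `processed/training.pt` and `processed/test.pt` paths under `root` for both loaders, it seems prudent to give `dset.MNIST` and `dset.FashionMNIST` different `root` directories so that one dataset's processed files are not mistaken for the other's.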
