Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added FashionMNIST dataset #238

Merged
merged 4 commits into from
Sep 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Datasets

The following dataset loaders are available:

- `MNIST <#mnist>`__
- `MNIST and FashionMNIST <#mnist>`__
- `COCO (Captioning and Detection) <#coco>`__
- `LSUN Classification <#lsun>`__
- `ImageFolder <#imagefolder>`__
Expand Down Expand Up @@ -77,6 +77,8 @@ MNIST
~~~~~
``dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)``

``dset.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)``

``root``: root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist

``train``: ``True`` - use training set, ``False`` - use test set.
Expand Down Expand Up @@ -390,32 +392,32 @@ For example:
Utils
=====

make\_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0)
``make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images

normalize=True will shift the image to the range (0, 1),
``normalize=True`` will shift the image to the range (0, 1),
by subtracting the minimum and dividing by the maximum pixel value.

if range=(min, max) where min and max are numbers, then these numbers are used to
if ``range=(min, max)`` where ``min`` and ``max`` are numbers, then these numbers are used to
normalize the image.

scale_each=True will scale each image in the batch of images separately rather than
computing the (min, max) over all images.
``scale_each=True`` will scale each image in the batch of images separately rather than
computing the ``(min, max)`` over all images.

pad_value=<float> sets the value for the padded pixels.
``pad_value=<float>`` sets the value for the padded pixels.

`Example usage is given in this notebook` <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>

save\_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0)
``save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Saves a given Tensor into an image file.

If given a mini-batch tensor, will save the tensor as a grid of images.

All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
All options after ``filename`` are passed through to ``make_grid``. Refer to it's documentation for
more details
4 changes: 2 additions & 2 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST
from .mnist import MNIST, FashionMNIST
from .svhn import SVHN
from .phototour import PhotoTour
from .fakedata import FakeData

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'CIFAR10', 'CIFAR100', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour')
11 changes: 11 additions & 0 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ def download(self):
print('Done!')


class FashionMNIST(MNIST):
"""`Fashion MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
"""
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]


def get_int(b):
return int(codecs.encode(b, 'hex'), 16)

Expand Down