forked from corenel/pytorch-adda
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathusps.py
136 lines (114 loc) · 4.56 KB
/
usps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
"""
import gzip
import os
import pickle
import urllib
import numpy as np
import torch
import torch.utils.data as data
from torchvision import datasets, transforms
import params
class USPS(data.Dataset):
    """USPS Dataset.

    Args:
        root (string): Root directory of dataset where dataset file exist.
        train (bool, optional): If True, load the training split and shuffle
            it once at construction time; otherwise load the test split in
            file order.
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform applied to each
            image, which is passed in as an HWC array (not a PIL image).
            E.g, ``transforms.ToTensor``

    Raises:
        RuntimeError: if the dataset file is missing after the optional
            download step.
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Per the original note: Num of Train = 7438, Num of Test = 1860.
        self.transform = transform
        # Filled in by load_samples() with the size of the selected split.
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            # One-off random permutation of the training split (uses the
            # global NumPy RNG, so it follows np.random.seed).
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        # NOTE(review): this presumably rescales [0, 1] pixels to [0, 255], but
        # the array stays floating point. Newer torchvision ToTensor only
        # divides by 255 for uint8 input, so float pixels keep the 0-255 range;
        # confirm this matches the Normalize mean/std used by callers.
        self.train_data *= 255.0
        # (N, C, H, W) -> (N, H, W, C) so each item is an HWC image.
        self.train_data = self.train_data.transpose(
            (0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a one-element LongTensor
            holding the class index.
        """
        img, label = self.train_data[index, ::], self.train_labels[index]
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([np.int64(label).item()])
        return img, label

    def __len__(self):
        """Return size of dataset."""
        return self.dataset_size

    def _check_exists(self):
        """Check if dataset is download and in right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset file into ``root`` unless it already exists."""
        # A bare ``import urllib`` does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly here.
        import urllib.request

        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        # Download to a temporary name and rename afterwards so an interrupted
        # transfer never leaves a truncated file that _check_exists accepts.
        partial = filename + ".part"
        urllib.request.urlretrieve(self.url, partial)
        os.replace(partial, filename)
        print("[DONE]")
        return

    def load_samples(self):
        """Load the split selected by ``self.train`` from the dataset file.

        Side effect: sets ``self.dataset_size`` to the split's sample count.

        Returns:
            tuple: ``(images, labels)`` for the selected split.
        """
        filename = os.path.join(self.root, self.filename)
        # ``with`` guarantees the gzip handle is closed even if unpickling
        # fails. encoding="bytes" presumably accommodates a pickle written by
        # Python 2.
        # NOTE(review): pickle can execute arbitrary code on load; this file
        # comes from a download, so only use a trusted source.
        with gzip.open(filename, "rb") as f:
            data_set = pickle.load(f, encoding="bytes")
        # Index 0 holds the training split, index 1 the test split; each is
        # an (images, labels) pair.
        split = data_set[0] if self.train else data_set[1]
        images, labels = split[0], split[1]
        self.dataset_size = labels.shape[0]
        return images, labels
def get_usps(train):
    """Build a DataLoader over one split of the USPS dataset.

    Each image is converted with ``transforms.ToTensor`` and then normalized
    with ``params.dataset_mean`` / ``params.dataset_std``. The dataset file is
    fetched into ``params.data_root`` if it is not already there.

    Args:
        train (bool): True selects the training split, False the test split.

    Returns:
        torch.utils.data.DataLoader: a shuffling loader yielding batches of
        ``params.batch_size`` samples.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=params.dataset_mean,
                                     std=params.dataset_std)
    pipeline = transforms.Compose([to_tensor, normalize])

    dataset = USPS(
        root=params.data_root,
        train=train,
        transform=pipeline,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=params.batch_size,
        shuffle=True,
    )