-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathdata_utils.py
62 lines (54 loc) · 1.51 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
from sklearn import utils as skutils
from rng import np_rng, py_rng
def center_crop(x, ph, pw=None):
    """Return the central ``ph`` x ``pw`` window of ``x``.

    Args:
        x: Array-like exposing ``shape`` whose first two axes are treated
            as height and width, and which supports 2-D slicing.
        ph: Height of the crop.
        pw: Width of the crop; defaults to ``ph`` (square crop).

    Returns:
        The slice ``x[top:top + ph, left:left + pw]`` where the offsets are
        half of the leftover margin, rounded to an integer.
    """
    pw = ph if pw is None else pw
    height, width = x.shape[:2]
    # Split the unused margin evenly between both sides of each axis.
    top = int(round((height - ph) / 2.))
    left = int(round((width - pw) / 2.))
    rows = slice(top, top + ph)
    cols = slice(left, left + pw)
    return x[rows, cols]
def patch(x, ph, pw=None):
    """Return a randomly positioned ``ph`` x ``pw`` window of ``x``.

    Args:
        x: Array-like exposing ``shape`` whose first two axes are treated
            as height and width, and which supports 2-D slicing.
        ph: Height of the patch.
        pw: Width of the patch; defaults to ``ph`` (square patch).

    Returns:
        The slice ``x[top:top + ph, left:left + pw]`` with offsets drawn
        from the module-level ``py_rng``.
    """
    pw = ph if pw is None else pw
    height, width = x.shape[:2]
    # NOTE(review): presumably py_rng is a random.Random, whose randint()
    # includes the upper bound, so the patch can touch the far edge.
    # The row offset is drawn before the column offset to keep the RNG
    # stream order unchanged.
    top = py_rng.randint(0, height - ph)
    left = py_rng.randint(0, width - pw)
    return x[top:top + ph, left:left + pw]
def list_shuffle(*data):
    """Shuffle one or more indexable sequences with a shared permutation.

    The permutation length is taken from the first sequence, and the same
    index order is applied to every sequence so that corresponding
    elements stay aligned.

    Args:
        *data: Sequences supporting integer indexing.

    Returns:
        A single shuffled list when exactly one sequence is given,
        otherwise a list containing one shuffled list per input sequence.
    """
    order = np_rng.permutation(np.arange(len(data[0])))
    shuffled = [[seq[k] for k in order] for seq in data]
    if len(data) == 1:
        # Unwrap so a single input yields a flat list, not a list of lists.
        return shuffled[0]
    return shuffled
def shuffle(*arrays, **options):
    """Shuffle the given collections consistently.

    If the first element of the first collection is a string, the inputs
    are shuffled with ``list_shuffle``; otherwise they are passed to
    ``sklearn.utils.shuffle`` using ``np_rng`` as the random state.

    Args:
        *arrays: Indexable collections to shuffle together. The first one
            must be non-empty because its first element is inspected.
        **options: Accepted for interface compatibility; not used.

    Returns:
        Whatever the selected shuffling routine returns.
    """
    # `basestring` only exists on Python 2. Fall back to the Python 3
    # string types so the type check does not raise NameError there.
    try:
        string_types = basestring
    except NameError:
        string_types = (str, bytes)

    if isinstance(arrays[0][0], string_types):
        return list_shuffle(*arrays)
    return skutils.shuffle(*arrays, random_state=np_rng)
def OneHot(X, n=None, negative_class=0.):
    """Encode integer labels as a 2-D one-hot array.

    Args:
        X: Labels; converted to an array and flattened. The values are
            used directly as column indices.
        n: Number of columns. When ``None`` it is ``max(X) + 1``.
        negative_class: Value placed in every position that is not the
            label's column.

    Returns:
        Array of shape ``(len(labels), n)`` holding ``1.`` at
        ``[row, label]`` and ``negative_class`` elsewhere.
    """
    labels = np.asarray(X).flatten()
    n_classes = np.max(labels) + 1 if n is None else n
    n_rows = len(labels)
    # Start from an all-"negative" grid, then switch on one cell per row.
    encoded = np.ones((n_rows, n_classes)) * negative_class
    encoded[np.arange(n_rows), labels] = 1.
    return encoded
def iter_data(*data, **kwargs):
    """Yield consecutive mini-batches sliced from one or more collections.

    Args:
        *data: Sliceable collections. The number of items is taken from
            the first one via ``len()``, falling back to ``shape[0]``.
        **kwargs: ``size`` -- batch size (default 128).

    Yields:
        For a single collection, the slice ``data[0][start:end]``. For
        several collections, a tuple with the matching slice of each one.
        The last batch is shorter when the item count is not a multiple
        of ``size``.
    """
    size = kwargs.get('size', 128)
    try:
        n = len(data[0])
    except TypeError:
        # len() raises TypeError for objects without a usable __len__;
        # such containers are expected to expose their length as shape[0].
        n = data[0].shape[0]

    # Floor division keeps the batch count an int on Python 3 as well;
    # plain `/` yields a float there and range() would reject it.
    batches = n // size
    if n % size != 0:
        batches += 1

    for b in range(batches):
        start = b * size
        end = min(start + size, n)  # clamp the final, possibly short batch
        if len(data) == 1:
            yield data[0][start:end]
        else:
            yield tuple([d[start:end] for d in data])