-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataInput.py
executable file
·56 lines (45 loc) · 1.75 KB
/
dataInput.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import matplotlib
# For Mac OS
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
class dataInput:
    """Load CIFAR-10 python-pickle batches and plot random test images.

    On construction the five training batch files and the test batch file are
    read from a data directory. Each batch is whatever ``pickle.load`` returns
    for that file; this class reads it through the ``b'labels'`` and
    ``b'data'`` keys.
    """

    # Directory the script has always read from; kept as the default so
    # existing callers that pass no arguments behave exactly as before.
    DEFAULT_DATA_DIR = './data_set/cifar_10'

    def __init__(self, data_dir=DEFAULT_DATA_DIR):
        """Load every batch from *data_dir* and set up the class-name table.

        Args:
            data_dir: directory containing ``data_batch_1`` .. ``data_batch_5``
                and ``test_batch``.
        """
        self.trainData, self.testData = self.dataLoad(data_dir)
        # Display names indexed by integer label id; used as column titles.
        self.labelName = ['airplane',
                          'automobile',
                          'bird',
                          'cat',
                          'deer',
                          'dog',
                          'frog',
                          'horse',
                          'ship',
                          'truck']

    def dataLoad(self, data_dir=DEFAULT_DATA_DIR):
        """Read the training and test batches from *data_dir*.

        Args:
            data_dir: directory holding the batch files.

        Returns:
            ``(train_batches, test_batch)``. ``train_batches`` is a list of
            the unpickled ``data_batch_1`` .. ``data_batch_5`` files in that
            order. ``test_batch`` is the unpickled ``test_batch`` file.
        """
        import os

        data = [
            self.unpickle(os.path.join(data_dir, 'data_batch_' + str(i)))
            for i in range(1, 6)
        ]
        testData = self.unpickle(os.path.join(data_dir, 'test_batch'))
        return data, testData

    def unpickle(self, file):
        """Open *file* in binary mode and return the unpickled object.

        ``encoding='bytes'`` makes any Python 2 ``str`` objects in the pickle
        load as ``bytes``. That is why batches are indexed with ``b'labels'``
        and ``b'data'`` elsewhere in this class.

        NOTE(review): unpickling can execute arbitrary code; only load files
        from a trusted source.
        """
        import pickle

        with open(file, 'rb') as fo:
            # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
            batch = pickle.load(fo, encoding='bytes')
        return batch

    def dataVisuallization(self, label):
        """Draw one random test image of class *label* on the current axes.

        Args:
            label: integer class id to sample from the test batch.

        Returns:
            The chosen image as a ``(3, 32, 32)`` channel-first array.

        Raises:
            ValueError: if no test image carries *label*.
        """
        labels = np.asarray(self.testData[b'labels'])
        candidates = np.flatnonzero(labels == label)
        if candidates.size == 0:
            # Without this check np.random.choice fails with an obscure
            # "a must be greater than 0" message.
            raise ValueError(f'no test image with label {label!r}')
        row = np.random.choice(candidates)
        # Each row of b'data' is one flattened image; reshape to channel-first.
        tmpData = np.asarray(self.testData[b'data'])[row, :].reshape([3, 32, 32])
        # imshow expects channel-last (H, W, C) for RGB data.
        plt.imshow(tmpData.transpose(1, 2, 0))
        return tmpData

    def dataVisuallizationSubplot(self, nRows=10):
        """Show a grid of random test images, one column per class.

        Args:
            nRows: number of random samples drawn per class. The default of
                10 reproduces the original 10 x 10 grid.
        """
        nCols = len(self.labelName)
        for row in range(nRows):
            for label in range(nCols):
                # subplot indices are 1-based and filled row-major.
                plt.subplot(nRows, nCols, row * nCols + label + 1)
                self.dataVisuallization(label)
                if row == 0:
                    # Only the top row carries the class-name titles.
                    plt.title(self.labelName[label])
        plt.show()
if __name__ == "__main__":
    # When run as a script, only load the CIFAR-10 batches; nothing is plotted.
    dataOb = dataInput()