-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbonsaiTrainer2.py
496 lines (411 loc) · 18.8 KB
/
bonsaiTrainer2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import torch
import numpy as np
import os
import sys
import edgeml_pytorch.utils as utils
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
class BonsaiTrainer:
    '''
    Trains a Bonsai model with the Dense -> IHT -> Sparse Retraining routine,
    evaluates it on a test set after every epoch and saves the parameters.
    '''

    def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
                 learningRate, useMCHLoss=False, outFile=None, device=None):
        '''
        bonsaiObj - Initialised Bonsai Object and Graph
        lW, lT, lV and lZ are regularisers to Bonsai Params
        sW, sT, sV and sZ are sparsity factors to Bonsai Params (each in [0, 1])
        learningRate - learningRate for the Adam optimizer
        useMCHLoss - For choice between HingeLoss vs CrossEntropy
        useMCHLoss - True - MultiClass - multiClassHingeLoss
        useMCHLoss - False - MultiClass - crossEntropyLoss
        outFile - path of the training log (opened for writing); None logs
                  to stdout
        device - torch device the parameters live on; None means "cpu"

        Raises AssertionError if a sparsity factor lies outside [0, 1].
        '''
        self.bonsaiObj = bonsaiObj
        self.lW = lW
        self.lV = lV
        self.lT = lT
        self.lZ = lZ
        self.sW = sW
        self.sV = sV
        self.sT = sT
        self.sZ = sZ
        # Validate before opening the log file so that a bad configuration
        # does not leave an open (and truncated) file behind.
        self.assertInit()
        self.device = "cpu" if device is None else device
        self.useMCHLoss = useMCHLoss
        self.learningRate = learningRate
        if outFile is not None:
            print("Outfile : ", outFile)
            self.outFile = open(outFile, 'w')
        else:
            self.outFile = sys.stdout
        # NOTE: this instance attribute shadows the optimizer() method; kept
        # because callers may use trainer.optimizer as the Adam instance.
        self.optimizer = self.optimizer()
        # Every sparsity factor ~1 means nothing is pruned, so the IHT and
        # sparse-retraining phases are skipped.
        self.isDenseTraining = all(
            s > 0.99 for s in (self.sW, self.sV, self.sZ, self.sT))
        # Supports recorded by the latest runHardThrsd(); None until then.
        self.__thrsdW = None
        self.__thrsdV = None
        self.__thrsdZ = None
        self.__thrsdT = None

    def loss(self, logits, labels):
        '''
        Loss function for given Bonsai Obj.

        Returns (loss, marginLoss, regLoss) where
        regLoss = 0.5 * (lZ*||Z||^2 + lW*||W||^2 + lV*||V||^2 + lT*||T||^2)
        and marginLoss is the binary hinge loss for numClasses <= 2, otherwise
        the multi-class hinge loss (useMCHLoss) or cross-entropy.
        '''
        bonsai = self.bonsaiObj
        regLoss = 0.5 * (self.lZ * (torch.norm(bonsai.Z) ** 2) +
                         self.lW * (torch.norm(bonsai.W) ** 2) +
                         self.lV * (torch.norm(bonsai.V) ** 2) +
                         self.lT * (torch.norm(bonsai.T) ** 2))
        if bonsai.numClasses > 2:
            if self.useMCHLoss is True:
                marginLoss = utils.multiClassHingeLoss(logits, labels)
            else:
                marginLoss = utils.crossEntropyLoss(logits, labels)
        else:
            marginLoss = utils.binaryHingeLoss(logits, labels)
        return marginLoss + regLoss, marginLoss, regLoss

    def optimizer(self):
        '''
        Optimizer for Bonsai Params (Adam with self.learningRate).
        '''
        return torch.optim.Adam(
            self.bonsaiObj.parameters(), lr=self.learningRate)

    @staticmethod
    def _binaryPredictions(logits):
        '''
        Predicted class (0/1) per row for the binary case: class 1 when the
        single logit is greater than 0 (a tie at 0 picks class 0).
        zeros_like keeps the comparison on the logits' device and dtype.
        '''
        return torch.cat((torch.zeros_like(logits), logits), 1).argmax(dim=1)

    def _labelsForMetrics(self, logits, labels):
        '''
        Return (trueLabels, predictedLabels) as 1-D CPU numpy arrays for the
        sklearn metrics. Multi-class labels are one-hot rows (argmax is
        taken, as in accuracy()); binary labels are flattened unchanged.
        '''
        if self.bonsaiObj.numClasses > 2:
            pred = logits.argmax(dim=1)
            true = labels.argmax(dim=1)
        else:
            pred = self._binaryPredictions(logits)
            true = labels.reshape(-1)
        return true.detach().cpu().numpy(), pred.detach().cpu().numpy()

    def accuracy(self, logits, labels):
        '''
        Accuracy function to evaluate accuracy when needed.
        Returns a 0-d tensor with the fraction of correct predictions.
        '''
        if self.bonsaiObj.numClasses > 2:
            correctPredictions = (logits.argmax(dim=1) == labels.argmax(dim=1))
        else:
            pred = self._binaryPredictions(logits)
            correctPredictions = (labels.reshape(-1).long() == pred)
        return torch.mean(correctPredictions.float())

    def classificationReport(self, logits, labels):
        '''
        sklearn classification_report (as a dict) for the given predictions.
        '''
        trueLabels, predLabels = self._labelsForMetrics(logits, labels)
        return classification_report(trueLabels, predLabels, output_dict=True)

    def confusion_matrix_FAR(self, logits, labels):
        '''
        Return (CM, FAR).

        CM is the confusion matrix (rows: true label, columns: prediction)
        over labels 0..K-1, with K = 2 for binary problems. It is always
        K x K, even if some class is absent.
        FAR (false alarm rate) is the fraction of class-0 samples predicted
        as any other class. For binary problems this is FP / (FP + TN).
        It is NaN when there are no class-0 samples.
        '''
        trueLabels, predLabels = self._labelsForMetrics(logits, labels)
        numClasses = self.bonsaiObj.numClasses
        numLabels = numClasses if numClasses > 2 else 2
        CM = confusion_matrix(trueLabels, predLabels,
                              labels=list(range(numLabels)))
        negatives = CM[0].sum()           # TN + FP
        falsePositives = negatives - CM[0][0]
        FAR = falsePositives / negatives if negatives > 0 else float('nan')
        return CM, FAR

    def f1(self, logits, labels):
        '''
        f1 score function to evaluate f1 when needed.
        Binary: sklearn's default binary F1 (positive class 1).
        Multi-class: macro-averaged F1 over the classes.
        '''
        trueLabels, predLabels = self._labelsForMetrics(logits, labels)
        if self.bonsaiObj.numClasses > 2:
            return f1_score(trueLabels, predLabels, average='macro')
        return f1_score(trueLabels, predLabels)

    def _hardThresholdParam(self, param, sparsity):
        '''
        Replace param.data with utils.hardThreshold(param, sparsity) and
        return a separate copy of the thresholded values (the support used
        later by utils.copySupport).
        '''
        thrsd = torch.as_tensor(
            utils.hardThreshold(param.data.cpu(), sparsity),
            dtype=torch.float32)
        param.data = thrsd.to(self.device)
        return thrsd.detach().clone().to(self.device)

    def runHardThrsd(self):
        '''
        Function to run the IHT routine on Bonsai Obj: hard-threshold W, V,
        Z and T with their sparsity factors and record the resulting supports.
        '''
        bonsai = self.bonsaiObj
        self.__thrsdW = self._hardThresholdParam(bonsai.W, self.sW)
        self.__thrsdV = self._hardThresholdParam(bonsai.V, self.sV)
        self.__thrsdZ = self._hardThresholdParam(bonsai.Z, self.sZ)
        self.__thrsdT = self._hardThresholdParam(bonsai.T, self.sT)

    def runSparseTraining(self):
        '''
        Function to run the Sparse Retraining routine on Bonsai Obj: restrict
        W, V, Z and T to the supports from the latest runHardThrsd() via
        utils.copySupport. If no hard thresholding has run yet (e.g. the IHT
        window was shorter than one trim interval), it is run first so that
        a support exists.
        '''
        if self.__thrsdW is None:
            self.runHardThrsd()
        bonsai = self.bonsaiObj
        pairs = ((bonsai.W, self.__thrsdW), (bonsai.V, self.__thrsdV),
                 (bonsai.Z, self.__thrsdZ), (bonsai.T, self.__thrsdT))
        for param, support in pairs:
            param.data = utils.copySupport(support, param.data)

    def assertInit(self):
        '''
        Check that every sparsity factor lies in [0, 1].
        Raises AssertionError otherwise. The raise is explicit so that the
        check also runs under `python -O`.
        '''
        err = "sparsity must be between 0 and 1"
        for name, value in (("W", self.sW), ("V", self.sV),
                            ("Z", self.sZ), ("T", self.sT)):
            if not (value >= 0 and value <= 1):
                raise AssertionError(name + " " + err)

    def saveParams(self, currDir):
        '''
        Function to save Parameter matrices (W, V, T, Z as .npy files) and a
        hyper-parameter dict (hyperParam.npy) into a given folder, creating
        the folder if needed.
        '''
        if currDir:
            os.makedirs(currDir, exist_ok=True)
        bonsai = self.bonsaiObj
        for name in ("W", "V", "T", "Z"):
            np.save(os.path.join(currDir, name + ".npy"),
                    getattr(bonsai, name).data.cpu().numpy())
        hyperParamDict = {'dataDim': bonsai.dataDimension,
                          'projDim': bonsai.projectionDimension,
                          'numClasses': bonsai.numClasses,
                          'depth': bonsai.treeDepth,
                          'sigma': bonsai.sigma}
        # A dict is stored as a pickled 0-d object array (see loadModel).
        np.save(os.path.join(currDir, 'hyperParam.npy'), hyperParamDict)

    def saveParamsForSeeDot(self, currDir):
        '''
        Function to save Parameter matrices into <currDir>/SeeDot/ as
        tab-separated text files for the SeeDot compiler. W and V are first
        restructured with utils.restructreMatrixBonsaiSeeDot.
        Raises OSError if the directory cannot be created.
        '''
        seeDotDir = os.path.join(currDir, 'SeeDot')
        os.makedirs(seeDotDir, exist_ok=True)
        bonsai = self.bonsaiObj
        for name in ("W", "V"):
            restructured = utils.restructreMatrixBonsaiSeeDot(
                getattr(bonsai, name).data.cpu(),
                bonsai.numClasses, bonsai.totalNodes)
            np.savetxt(os.path.join(seeDotDir, name), restructured,
                       delimiter="\t")
        np.savetxt(os.path.join(seeDotDir, "T"), bonsai.T.data.cpu().numpy(),
                   delimiter="\t")
        np.savetxt(os.path.join(seeDotDir, "Z"), bonsai.Z.data.cpu().numpy(),
                   delimiter="\t")
        np.savetxt(os.path.join(seeDotDir, "Sigma"),
                   np.array([bonsai.sigma]), delimiter="\t")

    def loadModel(self, currDir):
        '''
        Load the model saved by saveParams.
        Returns two dicts: one for params (W, V, T, Z numpy arrays) and one
        for hyperParams.
        '''
        paramDict = {name: np.load(os.path.join(currDir, name + ".npy"))
                     for name in ("W", "V", "T", "Z")}
        # The hyper-parameter dict is pickled, which modern NumPy refuses
        # to load unless allow_pickle=True.
        # SECURITY: unpickling can execute code; only load trusted model dirs.
        hyperParamDict = np.load(os.path.join(currDir, "hyperParam.npy"),
                                 allow_pickle=True).item()
        return paramDict, hyperParamDict

    # Function to get aimed model size
    def getModelSize(self):
        '''
        Function to get aimed model size.
        Returns (total non-zeros, total size, whether any matrix is sparse),
        summed over Z, W, V and T as estimated by utils.estimateNNZ.
        '''
        bonsai = self.bonsaiObj
        estimates = [utils.estimateNNZ(bonsai.Z, self.sZ),
                     utils.estimateNNZ(bonsai.W, self.sW),
                     utils.estimateNNZ(bonsai.V, self.sV),
                     utils.estimateNNZ(bonsai.T, self.sT)]
        totalnnZ = sum(nnz for nnz, _, _ in estimates)
        totalSize = sum(size for _, size, _ in estimates)
        hasSparse = any(sparse for _, _, sparse in estimates)
        return totalnnZ, totalSize, hasSparse

    def _printBanner(self, msg):
        '''Write a '*'-framed phase banner to the log.'''
        header = '*' * 20
        print("\n%s%s%s\n" % (header, msg, header), file=self.outFile)

    def _estimateSigmaI(self, Xtrain, itersInPhase, totalBatches,
                        numSamples=100):
        '''
        Re-estimate the indicator sigma from numSamples randomly drawn rows
        of Xtrain. The estimate is 0.1 / mean(|T[:internalNodes] @ Xcap|),
        with Xcap = Z @ X^T / projectionDimension. It is multiplied by
        2 ** (itersInPhase / (totalBatches / 30)) and capped at 1000.
        '''
        # Drawn unconditionally to keep the RNG stream independent of the
        # tree shape.
        indices = np.random.choice(Xtrain.shape[0], numSamples)
        batchX = Xtrain[indices, :].to(self.device)
        internalNodes = self.bonsaiObj.internalNodes
        if internalNodes > 0:
            with torch.no_grad():
                Xcap = torch.matmul(self.bonsaiObj.Z, torch.t(batchX)) / \
                    self.bonsaiObj.projectionDimension
                total = torch.matmul(self.bonsaiObj.T[:internalNodes],
                                     Xcap).abs().sum().item()
            meanAbs = total / (numSamples * internalNodes)
            # A zero mean makes the estimate unbounded; the cap then applies.
            sigma = 0.1 / meanAbs if meanAbs > 0 else float('inf')
        else:
            sigma = 0.1
        growth = 2 ** (float(itersInPhase) / (float(totalBatches) / 30.0))
        return min(1000, sigma * growth)

    def _applySparsityPhase(self, counter, ihtStart, sparseStart,
                            trimlevel, ihtDone):
        '''
        Apply the IHT / sparse-retraining step for global batch `counter`.
        In [ihtStart, sparseStart), hard thresholding runs every `trimlevel`
        batches, and the other batches keep the current support (once IHT
        has started). From sparseStart on, the support is always kept.
        Returns the updated ihtDone flag.
        '''
        if self.isDenseTraining:
            return ihtDone
        inIHT = ihtStart <= counter < sparseStart
        if inIHT and counter % trimlevel == 0:
            self.runHardThrsd()
            if ihtDone == 0:
                self._printBanner(" IHT Phase Started ")
            return 1
        if (inIHT and ihtDone == 1) or counter >= sparseStart:
            self.runSparseTraining()
            if counter == sparseStart:
                self._printBanner(" Sparse Retraining Phase Started ")
        return ihtDone

    def _evaluate(self, Xtest, Ytest):
        '''
        Run the model on the test set at the current self.sigmaI and return
        a dict with accuracy, F1, classification report, confusion matrix,
        FAR and the (total, margin, regularisation) losses as floats.
        '''
        Xtest = Xtest.to(self.device)
        Ytest = Ytest.to(self.device)
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            logits, _ = self.bonsaiObj(Xtest, self.sigmaI)
            testLoss, marginLoss, regLoss = self.loss(logits, Ytest)
            testAcc = self.accuracy(logits, Ytest).item()
        CM, FAR = self.confusion_matrix_FAR(logits, Ytest)
        return {'acc': testAcc,
                'f1': self.f1(logits, Ytest),
                'report': self.classificationReport(logits, Ytest),
                'CM': CM,
                'FAR': FAR,
                'loss': testLoss.item(),
                'marginLoss': marginLoss.item(),
                'regLoss': regLoss.item()}

    def train(self, batchSize, totalEpochs,
              Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
        '''
        The Dense - IHT - Sparse Retrain Routine for Bonsai Training.

        Each epoch runs floor(N / batchSize) consecutive batches of Xtrain;
        a trailing partial batch is dropped. Unless isDenseTraining is set,
        the total batch budget is split into thirds:
        - dense training,
        - IHT (hard thresholding every `trimlevel` batches),
        - sparse retraining on the fixed support.

        After every epoch the model is evaluated on (Xtest, Ytest) with
        sigmaI = 1e9. Once IHT has started (or from the start in dense
        mode), parameters are saved to currDir whenever the test accuracy
        is >= the best seen so far. A summary line is appended to
        <dataDir>/PyTorchBonsaiResults.txt. The log file (if it is not
        stdout) is closed at the end.

        Returns (classification report dict of the last epoch, extended with
        'train loss', 'train acc', 'test f1', 'model size' (KB) and
        'test far'; confusion matrix of the last epoch).

        Raises ValueError if batchSize <= 0, totalEpochs < 1, or batchSize
        exceeds the number of training samples.
        '''
        if batchSize <= 0:
            raise ValueError("batchSize must be positive")
        if totalEpochs < 1:
            raise ValueError("totalEpochs must be at least 1")
        numIters = int(Xtrain.shape[0] // batchSize)
        if numIters < 1:
            raise ValueError("batchSize (%s) exceeds the number of training "
                             "samples (%d)" % (batchSize, Xtrain.shape[0]))
        # Phase boundaries are based on the batches that actually run.
        totalBatches = numIters * totalEpochs
        ihtStart = int(totalBatches / 3.0)
        sparseStart = int(2 * totalBatches / 3.0)
        trimlevel = 15 if self.bonsaiObj.numClasses > 2 else 5
        numClasses = self.bonsaiObj.numClasses

        self.sigmaI = 1
        counter = 0
        itersInPhase = 0
        ihtDone = 1 if self.isDenseTraining else 0
        maxTestAcc = -10000
        maxTestAccEpoch = 0

        resultPath = os.path.join(dataDir, 'PyTorchBonsaiResults.txt')
        with open(resultPath, 'a+') as resultFile:
            for i in range(totalEpochs):
                print("\nEpoch Number: " + str(i), file=self.outFile)
                # trainAcc -> For Classification, it is 'Accuracy'.
                trainAcc = 0.0
                trainLoss = 0.0
                for j in range(numIters):
                    if counter == 0:
                        self._printBanner(" Dense Training Phase Started ")
                    # Reset sigma at each phase boundary; otherwise refresh
                    # it every 100 iterations of the current phase.
                    if (not self.isDenseTraining and
                            counter in (0, ihtStart, sparseStart)):
                        self.sigmaI = 1
                        itersInPhase = 0
                    elif itersInPhase % 100 == 0:
                        self.sigmaI = self._estimateSigmaI(
                            Xtrain, itersInPhase, totalBatches)
                    itersInPhase += 1

                    start = j * batchSize
                    batchX = Xtrain[start:start + batchSize].to(self.device)
                    batchY = Ytrain[start:start + batchSize].reshape(
                        -1, numClasses).to(self.device)

                    self.optimizer.zero_grad()
                    logits, _ = self.bonsaiObj(batchX, self.sigmaI)
                    batchLoss, _, _ = self.loss(logits, batchY)
                    batchAcc = self.accuracy(logits, batchY)
                    batchLoss.backward()
                    self.optimizer.step()

                    # Classification.
                    trainAcc += batchAcc.item()
                    trainLoss += batchLoss.item()

                    # Training routine involving IHT and sparse retraining
                    ihtDone = self._applySparsityPhase(
                        counter, ihtStart, sparseStart, trimlevel, ihtDone)
                    counter += 1

                finalTrainAcc = trainAcc / numIters
                finalTrainLoss = trainLoss / numIters
                print("\nClassification Train Loss: " + str(finalTrainLoss) +
                      "\nTraining accuracy (Classification): " +
                      str(finalTrainAcc),
                      file=self.outFile)

                # Evaluate with a very large sigma, restored afterwards.
                oldSigmaI = self.sigmaI
                self.sigmaI = 1e9
                metrics = self._evaluate(Xtest, Ytest)
                testAcc = metrics['acc']
                if ihtDone == 0:
                    # Before IHT starts no best accuracy is tracked and
                    # nothing is saved.
                    maxTestAcc = -10000
                    maxTestAccEpoch = i
                elif maxTestAcc <= testAcc:
                    maxTestAccEpoch = i
                    maxTestAcc = testAcc
                    self.saveParams(currDir)
                    self.saveParamsForSeeDot(currDir)

                print("Test accuracy %g" % testAcc, file=self.outFile)
                print("Test F1 ", metrics['f1'], file=self.outFile)
                print("Test False Alarm Rate ", metrics['FAR'],
                      file=self.outFile)
                print("Confusion Matrix \n", metrics['CM'], file=self.outFile)
                print("Classification Report \n", metrics['report'],
                      file=self.outFile)
                print("MarginLoss + RegLoss: " + str(metrics['marginLoss']) +
                      " + " + str(metrics['regLoss']) + " = " +
                      str(metrics['loss']) + "\n",
                      file=self.outFile)
                self.outFile.flush()
                self.sigmaI = oldSigmaI

            # sigmaI has to be set to infinity to ensure
            # only a single path is used in inference
            self.sigmaI = 1e9
            totalnnZ, totalSize, hasSparse = self.getModelSize()
            finalModelSize = float(totalSize) / 1024.0
            print("\nNon-Zero : " + str(totalnnZ) + " Model Size: " +
                  str(finalModelSize) + " KB hasSparse: " +
                  str(hasSparse) + "\n", file=self.outFile)
            print("For Classification, Maximum Test accuracy at compressed" +
                  " model size(including early stopping): " +
                  str(maxTestAcc) + " at Epoch: " +
                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
                  " Accuracy: " + str(testAcc), file=self.outFile)
            resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
                             " at Epoch(totalEpochs): " +
                             str(maxTestAccEpoch + 1) +
                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
                             str(finalModelSize) +
                             " KB hasSparse: " + str(hasSparse) +
                             " Param Directory: " +
                             str(os.path.abspath(currDir)) + "\n")
            print("The Model Directory: " + currDir + "\n")

        self.outFile.flush()
        # NOTE(review): closing the log means train() cannot be called twice
        # on the same trainer when outFile was a path.
        if self.outFile is not sys.stdout:
            self.outFile.close()

        finalClassificationReport = metrics['report']
        finalClassificationReport['train loss'] = finalTrainLoss
        finalClassificationReport['train acc'] = finalTrainAcc
        finalClassificationReport['test f1'] = metrics['f1']
        finalClassificationReport['model size'] = finalModelSize
        finalClassificationReport['test far'] = metrics['FAR']
        return (finalClassificationReport, metrics['CM'])