distillation_playground_cskd.py
import glob
import os
import sys

import torch
import torch.optim as optim

from config import get_config
from custom_models.models import CustomModels
from DataLoader import DatasetLoader
from KD_Lib.KD import CSKD
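
# KD_Lib provides the CSKD distiller used below (published on PyPI as "KD-Lib")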
# Experiment configuration (overrides the defaults from config.py)
config = get_config()
config.batch_size = 2048
config.dist_val_epochs = 100

# Dataset is chosen via the first CLI argument: 'fashion_mnist' or 'cifar10'
dataset = sys.argv[1]
num_classes = 10
if dataset == 'fashion_mnist':
    in_channel = 1
elif dataset == 'cifar10':
    in_channel = 3
else:
    raise ValueError(f"Unsupported dataset: {dataset}")

cmi = CustomModels(IN_CHANNEL=in_channel, NUM_OUTPUT=num_classes)
KD_METHOD = 'CSKD'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
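
# CSKD ("Regularizing Class-wise Predictions via Self-knowledge Distillation",
# Yun et al., CVPR 2020) distills a model from its own predictions on other
# samples of the same class, so no separate teacher network is needed.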
# CSKD - no teacher model; only student models to sweep over
student_models = [
    # cmi.init_model('model_25k_w_dw'),
    # cmi.init_model('model_25k_wo_dw'),
    # cmi.init_model('resnet_18'),
    # cmi.init_model('resnet_34'),
    # cmi.init_model('resnet_50'),
    # cmi.init_model('resnet_101'),
    # cmi.init_model('efficientnet-b5'),
    cmi.init_model('efficientnet-b7'),
    # cmi.init_model('model_143k_w_dw'),
    # cmi.init_model('model_143k_wo_dw'),
    # cmi.init_model('model_340k_w_dw'),
    # cmi.init_model('model_340k_wo_dw'),
    # cmi.init_model('model_600k_w_dw'),
    # cmi.init_model('model_600k_wo_dw'),
    # cmi.init_model('model_1M_w_dw'),
    # cmi.init_model('model_1M_wo_dw'),
]
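
# A minimal sketch of the class-wise term CSKD optimizes (paraphrased from the
# paper; KD_Lib's exact implementation may differ): KL divergence between the
# temperature-softened predictions of two different samples sharing a label.
#
#   import torch.nn.functional as F
#
#   def cskd_loss(logits_a, logits_b, temp=4.0):
#       # logits_a, logits_b: student outputs for two same-class samples
#       p = F.log_softmax(logits_a / temp, dim=1)
#       q = F.softmax(logits_b.detach() / temp, dim=1)
#       return F.kl_div(p, q, reduction='batchmean') * (temp ** 2)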
def getCheckpointModelPath(model_base_dir, model_type):
    """Return the first checkpoint (.pth) found under model_base_dir/model_type."""
    model_path = os.path.join(model_base_dir, model_type)
    model_path = os.path.abspath(glob.glob(f'{model_path}/*/*.pth')[0])
    return model_path
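
# Illustrative (hypothetical) use of the helper above, which is not called in
# this script but matches the save layout produced in the loop below:
#   ckpt = getCheckpointModelPath('kd_models_save/cifar10', 'CSKD')
#   student_model.load_state_dict(torch.load(ckpt, map_location=device))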
# Get the dataset loaders (train and test; no separate validation split)
dl = DatasetLoader(ds=dataset)
train_dl, test_dl = dl.getDataLoader(valid=False)
print("Dataset: ", str(dl._name))

# CSKD self-knowledge distillation: student model only, batch size 2048
for student_model in student_models:
    print("----STARTED----")
    print(student_model._name)

    # Freshly initialized student model
    student_model = student_model.to(device)
    student_optimizer = optim.Adam(student_model.parameters(), lr=config.learning_rate)

    # Student checkpoint save path (CSKD has no teacher checkpoint)
    student_save_model_pth = os.path.join('kd_models_save', dl._name, KD_METHOD, '_' + student_model._name, 'student.pth')
    dir_name = os.path.dirname(student_save_model_pth)
    os.makedirs(dir_name, exist_ok=True)

    # Experiment TensorBoard log directory
    logdir = os.path.join('./Experiments', dl._name, KD_METHOD)
    os.makedirs(logdir, exist_ok=True)

    print("CSKD Initialized")
    distiller = CSKD(teacher_model=None,
                     student_model=student_model,
                     train_loader=train_dl,
                     val_loader=test_dl,
                     optimizer_teacher=None,
                     optimizer_student=student_optimizer,
                     device=device,
                     log=True,
                     logdir=logdir)

    # Train the student model, then evaluate it on the test set
    distiller.train_student(epochs=config.dist_val_epochs, save_model_pth=student_save_model_pth)
    distiller.evaluate()

    print(student_model._name)
    print("----COMPLETED----")