train_transformation.py
#
# Copyright (C) 2023 Apple Inc. All rights reserved.
#
from typing import Dict

import yaml
from argparse import ArgumentParser

import torch
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

from trainers import TransformationTrainer
from dataset import SubImageFolder
from utils.net_utils import transformation_to_torchscripts
from utils.schedulers import get_policy
from utils.getters import get_model, get_optimizer, get_criteria


def main(config: Dict) -> None:
    """Run training.

    :param config: A dictionary with all configurations to run training.
    :return:
    """
    device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True

    # Build the transformation model and load the old/new models from
    # TorchScript checkpoints.
    model = get_model(config.get("arch_params"))
    old_model = torch.jit.load(config.get("old_model_path"))
    new_model = torch.jit.load(config.get("new_model_path"))
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model)
        old_model = torch.nn.DataParallel(old_model)
        new_model = torch.nn.DataParallel(new_model)
    model.to(device)
    old_model.to(device)
    new_model.to(device)

    # Optional side-information model; fall back to the old model if none is given.
    if config.get("side_info_model_path") is not None:
        side_info_model = torch.jit.load(config.get("side_info_model_path"))
        if torch.cuda.is_available():
            side_info_model = torch.nn.DataParallel(side_info_model)
        side_info_model.to(device)
    else:
        side_info_model = old_model

    optimizer = get_optimizer(model, **config.get("optimizer_params"))
    data = SubImageFolder(**config.get("dataset_params"))
    lr_policy = get_policy(optimizer, **config.get("lr_policy_params"))
    mus, criteria = get_criteria(**config.get("objective_params", {}))
    trainer = TransformationTrainer(
        old_model, new_model, side_info_model, **mus, **criteria
    )

    for epoch in range(config.get("epochs")):
        lr_policy(epoch, iteration=None)

        # Optionally enable switch_mode_to_eval for the second half of training.
        if config.get("switch_mode_to_eval"):
            switch_mode_to_eval = epoch >= config.get("epochs") / 2
        else:
            switch_mode_to_eval = False

        train_loss = trainer.train(
            train_loader=data.train_loader,
            model=model,
            optimizer=optimizer,
            device=device,
            switch_mode_to_eval=switch_mode_to_eval,
        )
        print("Train: epoch = {}, Average Loss = {}".format(epoch, train_loss))

        # Evaluate on the validation set.
        test_loss = trainer.validate(
            val_loader=data.val_loader,
            model=model,
            device=device,
        )
        print("Test: epoch = {}, Average Loss = {}".format(epoch, test_loss))

    # Export the trained transformation and the transformed old model as TorchScript.
    transformation_to_torchscripts(
        old_model,
        side_info_model,
        model,
        config.get("output_transformation_path"),
        config.get("output_transformed_old_model_path"),
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to config file for this pipeline.",
    )
    args = parser.parse_args()

    with open(args.config) as f:
        read_config = yaml.safe_load(f)

    main(read_config)
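
# ---------------------------------------------------------------------------
# Example usage (a sketch):
#     python train_transformation.py --config config.yaml
# The config keys below are the ones this script reads via config.get(...);
# the file names, numeric values, and the contents of the nested {...} blocks
# are hypothetical placeholders whose exact fields depend on get_model,
# get_optimizer, SubImageFolder, get_policy, and get_criteria.
#
#   arch_params: {...}                      # forwarded to get_model
#   old_model_path: /path/to/old_model.pt   # TorchScript checkpoint
#   new_model_path: /path/to/new_model.pt   # TorchScript checkpoint
#   side_info_model_path: null              # optional; old model is used if null
#   optimizer_params: {...}                 # forwarded to get_optimizer
#   dataset_params: {...}                   # forwarded to SubImageFolder
#   lr_policy_params: {...}                 # forwarded to get_policy
#   objective_params: {...}                 # forwarded to get_criteria
#   epochs: 50
#   switch_mode_to_eval: false              # if true, applied for the second half of training
#   output_transformation_path: /path/to/transformation.pt
#   output_transformed_old_model_path: /path/to/transformed_old_model.pt
# ---------------------------------------------------------------------------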