-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy path pretrain.py
71 lines (57 loc) · 2.39 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import sys
import math
import pprint
import torch
from torchdrug import core, models, tasks, datasets, utils
from torchdrug.utils import comm
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import util
from gearnet import model, layer
def save(solver, path):
    """Write the weights of the model wrapped by ``solver.model`` to ``path``.

    ``solver.model`` is a task object. When it is a ``tasks.Unsupervised``
    task, the model is nested one level deeper
    (``solver.model.model.model``) than for other tasks
    (``solver.model.model``). Only rank 0 logs and calls ``torch.save``.
    Every rank then calls ``comm.synchronize()``, presumably a cross-process
    barrier, so all ranks are expected to call this function (TODO confirm).

    NOTE: this reads the module-level name ``logger``, which this script only
    binds inside its ``__main__`` section.

    Args:
        solver: the solver/engine built by ``util.build_pretrain_solver``;
            its ``.model`` attribute is the task being trained.
        path: destination file name; ``~`` is expanded before writing.
    """
    task = solver.model
    if not isinstance(task, tasks.Unsupervised):
        net = task.model
    else:
        net = task.model.model

    if comm.get_rank() == 0:
        # The message shows the path exactly as passed in (before ~ expansion).
        logger.warning("Save checkpoint to %s" % path)
    target = os.path.expanduser(path)
    if comm.get_rank() == 0:
        torch.save(net.state_dict(), target)
    comm.synchronize()
if __name__ == "__main__":
    # Extra command-line values passed as the config context; named
    # ``cli_vars`` rather than ``vars`` to avoid shadowing the builtin.
    args, cli_vars = util.parse_args()
    cfg = util.load_config(args.config, context=cli_vars)
    # Called for its side effect only; the returned value is not used here.
    util.create_working_directory(cfg)

    # Offset the seed by rank so each process gets a different seed.
    torch.manual_seed(args.seed + comm.get_rank())

    # Must stay a module-level name: ``save`` reads the global ``logger``.
    logger = util.get_root_logger()
    if comm.get_rank() == 0:
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))

    # Optional multi-species mode: when species_end > species_start, train on
    # every split of every species id in [species_start, species_end),
    # reloading the dataset for each split.
    species_start = cfg.dataset.get("species_start", 0)
    species_end = cfg.dataset.get("species_end", 0)
    if species_end < species_start:
        # Explicit error instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            "dataset.species_end (%s) must be >= dataset.species_start (%s)"
            % (species_end, species_start))
    multi_species = species_end > species_start
    # These keys are consumed by this script; strip them even when the range
    # is empty so they are never forwarded to the dataset via load_config_dict.
    for key in ("species_start", "species_end"):
        if key in cfg.dataset:
            cfg.dataset.pop(key)
    if multi_species:
        cfg.dataset.species_id = species_start
        cfg.dataset.split_id = 0
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    solver = util.build_pretrain_solver(cfg, dataset)

    # Train in chunks of ``save_interval`` epochs, checkpointing after each chunk.
    step = cfg.get("save_interval", 1)
    if step < 1:
        # range() rejects a zero step, and a negative step would silently skip training.
        raise ValueError("save_interval must be a positive integer, got %r" % (step,))
    for i in range(0, cfg.train.num_epoch, step):
        kwargs = cfg.train.copy()
        # The last chunk may be shorter than ``step``.
        kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)
        if not multi_species:
            solver.train(**kwargs)
        else:
            for species_id in range(species_start, species_end):
                # ``dataset`` is whichever split was loaded most recently; its
                # ``species_nsplit`` table is assumed identical across splits — TODO confirm.
                for split_id in range(dataset.species_nsplit[species_id]):
                    cfg.dataset.species_id = species_id
                    cfg.dataset.split_id = split_id
                    dataset = core.Configurable.load_config_dict(cfg.dataset)
                    # Log once (rank 0), consistent with the other log calls here.
                    if comm.get_rank() == 0:
                        logger.warning('Epoch: {}\tSpecies id: {}\tSplit id: {}\tSplit length: {}'.format(
                            i, species_id, split_id, len(dataset)))
                    solver.train_set = dataset
                    solver.train(**kwargs)
        # Checkpoint name records the number of epochs completed so far.
        save(solver, "model_epoch_%d.pth" % (i + kwargs["num_epoch"]))