-
-
Notifications
You must be signed in to change notification settings - Fork 40
/
dir_interp.py
80 lines (61 loc) · 2.55 KB
/
dir_interp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import random
import torch
import argparse
from collections import OrderedDict
import os
import os.path
import sys
####################################################################################################################################
# Average all .pth models found (recursively) under a given directory ('-intdir') and save the result to a given destination file ('-savepath')
####################################################################################################################################
MODEL_EXTENSIONS = ['.pth']


def is_model_file(filename):
    """Return True when *filename* ends with one of ``MODEL_EXTENSIONS``."""
    # str.endswith accepts a tuple of suffixes, so one call covers them all.
    return filename.endswith(tuple(MODEL_EXTENSIONS))
def _get_paths_from_models(path):
    """Collect every model file found under *path*, searched recursively.

    Args:
        path: Directory to walk.

    Returns:
        A list of file paths (directory joined with file name) for which
        ``is_model_file`` is true, ordered by directory and then by file name.

    Raises:
        AssertionError: If *path* is not a directory, or if no model file is
            found beneath it.
    """
    # Raised explicitly rather than through `assert` so the checks still run
    # under `python -O`; the exception type is kept for backward compatibility.
    if not os.path.isdir(path):
        raise AssertionError('{:s} is not a valid directory'.format(path))

    model_list = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_model_file(fname):
                model_list.append(os.path.join(dirpath, fname))

    if not model_list:
        raise AssertionError('{:s} has no valid model file'.format(path))
    return model_list
def main(args):
    """Average the model files found under ``args.intdir`` and save the result.

    Every path returned by ``_get_paths_from_models(args.intdir)`` is loaded
    with ``torch.load``. The first file defines the set of keys; values from
    the remaining files are added key by key, and each sum is divided by the
    number of files that actually contributed to that key. The result is
    written with ``torch.save``.

    Args:
        args: Parsed command-line namespace with two attributes:
            ``savepath`` - output file; when falsy the default
            '../../experiments/pretrained_models/dirinterp.pth' is used.
            ``intdir`` - directory to search; when falsy only a message is
            printed and nothing is loaded or saved.
    """
    if args.savepath:
        net_interp_path = args.savepath
    else:
        net_interp_path = '../../experiments/pretrained_models/dirinterp.pth'

    if not args.intdir:
        print('No directory defined')
        return

    model_list = _get_paths_from_models(args.intdir)

    net_interp = OrderedDict()
    key_counts = {}
    for index, path in enumerate(model_list):
        # NOTE(security): torch.load unpickles the file; only point this
        # script at model files from a trusted source.
        # map_location='cpu' lets checkpoints saved on any device be loaded
        # and summed on any machine (no GPU needed, no device mismatch).
        net = torch.load(path, map_location='cpu')
        if index == 0:
            # Reuse the first loaded dict as the accumulator so any attributes
            # attached to it are carried over to the saved result.
            net_interp = net
            key_counts = dict.fromkeys(net_interp, 1)
        else:
            ignored = 0
            for k, v_net in net.items():
                if k in net_interp:
                    net_interp[k] = v_net + net_interp[k]
                    key_counts[k] += 1
                else:
                    ignored += 1
            if ignored:
                print('Warning: ' + str(ignored) + ' key(s) in ' + str(path)
                      + ' are not in the first model and were ignored.')
        print(str(path)+" added.")

    num_models = len(model_list)
    print(str(num_models)+" models combined")

    partial = [k for k, count in key_counts.items() if count < num_models]
    if partial:
        print('Warning: ' + str(len(partial)) + ' key(s) were missing from '
              'some models and were averaged over fewer files.')

    # Divide by the per-key contribution count rather than the total number
    # of models, so a key absent from some files is still a true average.
    for k, v_net in net_interp.items():
        net_interp[k] = v_net / key_counts[k]

    torch.save(net_interp, net_interp_path)
    print('model saved in: ', net_interp_path)
if __name__ == '__main__':
    # Command-line entry point: both options are optional strings.
    cli = argparse.ArgumentParser()
    # Where to write the combined model.
    cli.add_argument(
        '-savepath', '-p',
        type=str,
        required=False,
        help='Path and filename for new model',
    )
    # Directory whose model files are combined.
    cli.add_argument(
        '-intdir',
        type=str,
        required=False,
        help='Directory to combine models',
    )
    main(cli.parse_args())