forked from baowenbo/DAIN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmy_args.py
118 lines (94 loc) · 6.07 KB
/
my_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import datetime
import argparse
import numpy
import networks
import torch
modelnames = networks.__all__
# import datasets
# A one-element tuple (note the trailing comma). A bare parenthesised string
# would make argparse's `choices` test a substring match, and ' | '.join()
# below would list the individual characters.
datasetNames = ('Vimeo_90K_interp',)  # datasets.__all__


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    True (every non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line spellings accepted by --dtype. Without this mapping argparse
# would compare the raw string against the tensor-type `choices`, so every
# explicit --dtype value was rejected.
_DTYPES_BY_NAME = {
    'torch.cuda.FloatTensor': torch.cuda.FloatTensor,
    'torch.FloatTensor': torch.FloatTensor,
}


def _parse_dtype(name):
    """Map a --dtype string to the corresponding torch tensor type.

    Raises:
        argparse.ArgumentTypeError: if *name* is not a key of _DTYPES_BY_NAME.
    """
    try:
        return _DTYPES_BY_NAME[name]
    except KeyError:
        raise argparse.ArgumentTypeError(
            'unknown dtype %r (choose from %s)' % (name, ', '.join(_DTYPES_BY_NAME))
        ) from None


parser = argparse.ArgumentParser(description='DAIN')
parser.add_argument('--debug', action='store_true', help='Enable debug mode')
parser.add_argument('--netName', type=str, default='DAIN',
                    choices=modelnames,
                    help='model architecture: ' + ' | '.join(modelnames) + ' (default: DAIN)')
# NOTE(review): with nargs='+' an explicit --datasetName yields a list, while
# the default stays a plain string; consumers must cope with both forms.
parser.add_argument('--datasetName', default='Vimeo_90K_interp',
                    choices=datasetNames, nargs='+',
                    help='dataset type : ' + ' | '.join(datasetNames) + ' (default: Vimeo_90K_interp)')
parser.add_argument('--datasetPath', default='', help='the path of selected datasets')
parser.add_argument('--dataset_split', type=int, default=97,
                    help='Split a dataset into training and validation by percentage (default: 97)')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--numEpoch', '-e', type=int, default=100,
                    help='Number of epochs to train (default: 100)')
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size (default:1)')
parser.add_argument('--workers', '-w', type=int, default=8,
                    help='parallel workers for loading training samples (default: 8)')
parser.add_argument('--channels', '-c', type=int, default=3, choices=[1, 3],
                    help='channels of images (default:3)')
parser.add_argument('--filter_size', '-f', type=int, default=4,
                    help='the size of filters used (default: 4)',
                    choices=[2, 4, 6, 5, 51])
parser.add_argument('--lr', type=float, default=0.002,
                    help='the basic learning rate for three subnetworks (default: 0.002)')
parser.add_argument('--rectify_lr', type=float, default=0.001,
                    help='the learning rate for rectify/refine subnetworks (default: 0.001)')
parser.add_argument('--save_which', '-s', type=int, default=1, choices=[0, 1],
                    help='choose which result to save: 0 ==> interpolated, 1==> rectified')
parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps')
# Learning-rate coefficients, each relative to --lr (per their help texts).
parser.add_argument('--flow_lr_coe', type=float, default=0.01,
                    help='relative learning rate w.r.t basic learning rate (default: 0.01)')
parser.add_argument('--occ_lr_coe', type=float, default=1.0,
                    help='relative learning rate w.r.t basic learning rate (default: 1.0)')
parser.add_argument('--filter_lr_coe', type=float, default=1.0,
                    help='relative learning rate w.r.t basic learning rate (default: 1.0)')
parser.add_argument('--ctx_lr_coe', type=float, default=1.0,
                    help='relative learning rate w.r.t basic learning rate (default: 1.0)')
parser.add_argument('--depth_lr_coe', type=float, default=0.001,
                    help='relative learning rate w.r.t basic learning rate (default: 0.001)')
# parser.add_argument('--deblur_lr_coe', type = float, default=0.01, help = 'relative learning rate w.r.t basic learning rate (default: 0.01)')
parser.add_argument('--alpha', type=float, nargs='+', default=[0.0, 1.0],
                    help='the ratio of loss for interpolated and rectified result (default: [0.0, 1.0])')
parser.add_argument('--epsilon', type=float, default=1e-6,
                    help='the epsilon for Charbonnier loss, etc (default: 1e-6)')
parser.add_argument('--weight_decay', type=float, default=0, help='the weight decay for whole network ')
parser.add_argument('--patience', type=int, default=5, help='the patience of reduce on plateau')
parser.add_argument('--factor', type=float, default=0.2, help='the factor of reduce on plateau')
#
parser.add_argument('--pretrained', dest='SAVED_MODEL', default=None,
                    help='path to the pretrained model weights')
parser.add_argument('--no-date', action='store_true', help='don\'t append date timestamp to folder')
# _str2bool instead of type=bool so that "--use_cuda False" really disables CUDA.
parser.add_argument('--use_cuda', default=True, type=_str2bool, help='use cuda or not (default: True)')
parser.add_argument('--use_cudnn', default=1, type=int, help='use cudnn or not')
parser.add_argument('--dtype', default=torch.cuda.FloatTensor, type=_parse_dtype,
                    choices=[torch.cuda.FloatTensor, torch.FloatTensor],
                    metavar='{torch.cuda.FloatTensor,torch.FloatTensor}',
                    help='tensor data type ')
# parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
parser.add_argument('--uid', type=str, default=None, help='unique id for the training')
parser.add_argument('--force', action='store_true', help='force to override the given uid')
# First pass: --uid and --force are needed below to decide the output directory
# before the path-derived options (--save_path/--log/--arg) are registered.
args = parser.parse_args()
import shutil

# Resolve the output directory for this run from --uid.
if args.uid is None:
    # No --uid given: draw a random id and append a timestamp so that runs are
    # unlikely to share a directory.
    unique_id = str(numpy.random.randint(0, 100000))
    print("revise the unique id to a random number " + str(unique_id))
    args.uid = unique_id
    # NOTE(review): '%H:%M' puts a ':' in the directory name, which is not a
    # valid path character on Windows.
    timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M")
    save_path = './model_weights/' + args.uid + '-' + timestamp
else:
    save_path = './model_weights/' + str(args.uid)

_best_model_path = save_path + "/best" + ".pth"
if not os.path.exists(_best_model_path):
    # Fresh (or unfinished) run: make sure the directory exists.
    os.makedirs(save_path, exist_ok=True)
elif not args.force:
    # The original `raise("...")` raised a str, which itself fails with TypeError.
    raise FileExistsError("please use another uid (" + _best_model_path +
                          " already exists; pass --force to override it)")
else:
    print("override this uid " + args.uid)
    # Keep the previous log/args by copying them into the first free
    # *.bk1 ... *.bk9 slot (a slot counts as free when log.txt.bk<m> is absent).
    for m in range(1, 10):
        if not os.path.exists(save_path + "/log.txt.bk" + str(m)):
            for _name in ("/log.txt", "/args.txt"):
                _src = save_path + _name
                # Skip a missing file instead of aborting with FileNotFoundError.
                if os.path.exists(_src):
                    shutil.copy(_src, _src + ".bk" + str(m))
            break
    else:
        print("warning: backup slots .bk1-.bk9 are all taken; "
              "log.txt and args.txt will not be backed up")
# Expose the resolved output locations as regular options so they become part
# of `args` (and of the saved args.txt).
# NOTE(review): these three options cannot be given on the command line, because
# the earlier first-pass parse_args() does not know them and rejects them as
# unrecognized arguments.
parser.add_argument('--save_path', default=save_path, help='the output dir of weights')
parser.add_argument('--log', default=save_path + '/log.txt', help='the log file in training')
parser.add_argument('--arg', default=save_path + '/args.txt', help='the args used')
# Re-parsing builds a fresh Namespace, which would drop a uid generated when
# --uid was omitted (args.uid would be None again), so carry it over.
_resolved_uid = args.uid
args = parser.parse_args()
args.uid = _resolved_uid

# Create (or truncate) the training log so the run starts with an empty file.
with open(args.log, 'w'):
    pass

# Echo the final configuration and persist it next to the weights.
print(args)
with open(args.arg, 'w') as f:
    print(args, file=f)
# Apply --use_cudnn: a truthy value turns on cuDNN autotuning
# (torch.backends.cudnn.benchmark); a falsy value turns it off.
# NOTE(review): only the benchmark flag is toggled here; cuDNN itself is not
# disabled (torch.backends.cudnn.enabled is left untouched).
_cudnn_on = bool(args.use_cudnn)
print("cudnn is used" if _cudnn_on else "cudnn is not used")
torch.backends.cudnn.benchmark = _cudnn_on