-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
67 lines (54 loc) · 2.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import argparse
import random
import numpy as np
import os
from trainer.ConditionalTrainer import ConditionalNPTrainer as Trainer
def _build_parser():
    """Return the ArgumentParser for the training CLI.

    The parsed namespace is handed to ``Trainer``, which reads the options
    by attribute name, so option names and defaults must stay stable.
    """
    parser = argparse.ArgumentParser()

    # Model selection
    parser.add_argument('--test_model', choices=['NODE', 'NP'], default='NODE',
                        help='NP = transformer for both encoder and decoder')
    parser.add_argument('--model_type', choices=['FNODEs', 'FNP', 'NP', 'NODEs'],
                        default='FNODEs')
    parser.add_argument('--NP', action='store_true')
    parser.add_argument('--encoder', choices=['Conv'], default='Conv')
    parser.add_argument('--decoder',
                        choices=['Fourier', 'ODE', 'NP', 'Transformer', 'RNN'])

    # Encoder
    parser.add_argument('--encoder_hidden_dim', type=int, default=32)
    parser.add_argument('--encoder_blocks', type=int, default=3)

    # Decoder
    parser.add_argument('--decoder_layers', type=int, default=2)
    parser.add_argument('--decoder_hidden_dim', type=int, default=256)

    # Feature sizes and harmonic/expansion settings
    parser.add_argument('--in_features', type=int, default=1)
    parser.add_argument('--out_features', type=int, default=1)
    parser.add_argument('--latent_dimension', type=int, default=3,
                        help='dimension for NP')
    parser.add_argument('--expfunc', type=str, default='fourier')
    parser.add_argument('--n_harmonics', type=int, default=1)
    parser.add_argument('--n_eig', type=int, default=2)
    parser.add_argument('--lower_bound', type=float, default=1)
    parser.add_argument('--upper_bound', type=float,
                        help='defaults to lower_bound + n_harmonics - 1')
    parser.add_argument('--skip_step', type=int)

    # Training hyper-parameters
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--n_epochs', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--dropout', type=float, default=0.1)

    # Paths and data
    parser.add_argument('--path', type=str, default='./',
                        help='parameter saving path')
    parser.add_argument('--dataset_path', type=str, default='./input/')
    parser.add_argument('--dataset_name', type=str)
    parser.add_argument('--dataset_type', choices=['sin', 'ECG'])

    # Runtime
    parser.add_argument('--device_num', type=str, default='0')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed for Python, NumPy and PyTorch')
    return parser


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; this call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different kernels run-to-run; keep it off.
    torch.backends.cudnn.benchmark = False


def main(argv=None):
    """Parse command-line options, seed all RNGs and run training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Exits via ``parser.error`` (status 2) when ``--upper_bound`` is given
    but is inconsistent with ``--lower_bound`` and ``--n_harmonics``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Label count per dataset type (presumably the number of classes).
    if args.dataset_type == 'sin':
        args.num_label = 4
    elif args.dataset_type == 'ECG':
        args.num_label = 3
    # NOTE(review): num_label stays unset when --dataset_type is omitted;
    # presumably the Trainer requires it — confirm.

    # The harmonic bounds must satisfy
    #   upper_bound - lower_bound + 1 == n_harmonics.
    # --upper_bound has no default, so derive it when omitted instead of
    # crashing on ``None - float``. Cast to float to match ``type=float``.
    if args.upper_bound is None:
        args.upper_bound = float(args.lower_bound + args.n_harmonics - 1)
    elif args.upper_bound - args.lower_bound + 1 != args.n_harmonics:
        parser.error(
            'the number of harmonics and lower and upper bound should match: '
            f'expected upper_bound = lower_bound + n_harmonics - 1 = '
            f'{args.lower_bound + args.n_harmonics - 1}, '
            f'got {args.upper_bound}'
        )

    # CUDA reads this when it initialises, so set it before any torch.cuda
    # call (the seeding below).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
    _seed_everything(args.seed)

    trainer = Trainer(args)
    trainer.train()
# Script entry point: start training only when run directly, never on import.
if __name__ == '__main__':
    main()