-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain_gana.py
79 lines (68 loc) · 2.7 KB
/
main_gana.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from trainer_gana import *
from params import *
from data_loader import *
import json
def load_json(path):
    """Parse the JSON file at *path* and return the decoded object.

    The file is read inside a ``with`` block so its handle is closed as soon
    as parsing finishes, rather than being left open until garbage collection.
    """
    with open(path) as f:
        return json.load(f)


if __name__ == '__main__':
    params = get_params()
    print("---------Parameters---------")
    for k, v in params.items():
        print(k + ': ' + str(v))
    print("----------------------------")

    # Fail fast on an unsupported step, before the (potentially large) dataset
    # is loaded and the trainer is built; previously such a value was silently
    # ignored and the script did nothing.
    if params['step'] not in ('train', 'test', 'dev'):
        raise ValueError(
            "Unknown step: {!r} (expected 'train', 'test' or 'dev')".format(params['step'])
        )

    # Control random seeds so that runs are reproducible when a seed is given.
    if params['seed'] is not None:
        SEED = params['seed']
        torch.manual_seed(SEED)
        torch.cuda.manual_seed(SEED)
        # Ask cuDNN for deterministic kernels (may trade away some speed).
        torch.backends.cudnn.deterministic = True
        np.random.seed(SEED)
        random.seed(SEED)

    # Prefix every dataset path in data_dir (mutated in place) with the
    # configured data directory. NOTE: this is plain string concatenation, so
    # separators must already be present in data_path or in the entries.
    for k, v in data_dir.items():
        data_dir[k] = params['data_path'] + v

    # The '_in_train' variants of train_tasks / rel2candidates / e1rel_e2 are
    # always used; a former switch on params['data_form'] == 'In-Train' that
    # selected them is disabled.
    tail = '_in_train'

    dataset = dict()
    print("loading train_tasks{} ... ...".format(tail))
    dataset['train_tasks'] = load_json(data_dir['train_tasks' + tail])
    print("loading test_tasks ... ...")
    dataset['test_tasks'] = load_json(data_dir['test_tasks'])
    print("loading dev_tasks ... ...")
    dataset['dev_tasks'] = load_json(data_dir['dev_tasks'])
    print("loading rel2candidates{} ... ...".format(tail))
    dataset['rel2candidates'] = load_json(data_dir['rel2candidates' + tail])
    print("loading e1rel_e2{} ... ...".format(tail))
    dataset['e1rel_e2'] = load_json(data_dir['e1rel_e2' + tail])
    print("loading ent2id ... ...")
    dataset['ent2id'] = load_json(data_dir['ent2ids'])
    dataset['rel2id'] = load_json(data_dir['rel2ids'])

    # Pre-computed TransE embedding matrices are only needed (and loaded)
    # for the 'Pre-Train' data form.
    if params['data_form'] == 'Pre-Train':
        print('loading embedding ... ...')
        dataset['ent2emb'] = np.loadtxt(params['data_path'] + '/entity2vec.TransE')
        dataset['rel2emb'] = np.loadtxt(params['data_path'] + '/relation2vec.TransE')
    print("----------------------------")

    # One loader per split, all built over the same dataset dict.
    train_data_loader = DataLoader(dataset, params, step='train')
    dev_data_loader = DataLoader(dataset, params, step='dev')
    test_data_loader = DataLoader(dataset, params, step='test')
    data_loaders = [train_data_loader, dev_data_loader, test_data_loader]

    trainer = Trainer(data_loaders, dataset, params)

    if params['step'] == 'train':
        trainer.train()
        print("test")
        print(params['prefix'])
        # NOTE(review): reload() presumably restores the saved/best checkpoint
        # before the final test evaluation — confirm in trainer_gana.
        trainer.reload()
        trainer.eval(istest=True)
    elif params['step'] == 'test':
        print(params['prefix'])
        if params['eval_by_rel']:
            trainer.eval_by_relation(istest=True)
        else:
            trainer.eval(istest=True)
    elif params['step'] == 'dev':
        print(params['prefix'])
        if params['eval_by_rel']:
            trainer.eval_by_relation(istest=False)
        else:
            trainer.eval(istest=False)