-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_complex_WN18RR.py
46 lines (40 loc) · 1.26 KB
/
train_complex_WN18RR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os

import openke
from openke.config import Trainer, Tester
from openke.module.model import ComplEx
from openke.module.loss import SoftplusLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader
# Train a ComplEx embedding model on the WN18RR benchmark with OpenKE, save
# the learned parameters to a checkpoint, then reload that checkpoint and
# report link-prediction metrics.

# Each path is defined once so the train/test loaders and the save/load
# calls cannot drift apart.
BENCHMARK_DIR = "./benchmarks/WN18RR/"
CHECKPOINT_PATH = "./checkpoint/complEx.ckpt"

# dataloader for training
train_dataloader = TrainDataLoader(
    in_path = BENCHMARK_DIR,
    nbatches = 100,
    threads = 8,
    sampling_mode = "normal",
    bern_flag = 1,
    filter_flag = 1,
    neg_ent = 25,  # negative entities per positive triple (per OpenKE naming; confirm)
    neg_rel = 0    # no relation corruption
)

# dataloader for test ("link" = link-prediction evaluation data)
test_dataloader = TestDataLoader(BENCHMARK_DIR, "link")

# define the model; entity/relation counts come from the training data
complEx = ComplEx(
    ent_tot = train_dataloader.get_ent_tot(),
    rel_tot = train_dataloader.get_rel_tot(),
    dim = 200
)

# define the loss function: wrap the model with negative sampling and a
# softplus loss; regul_rate weights the model's regularization term
# (assumed from the parameter name — confirm against OpenKE).
model = NegativeSampling(
    model = complEx,
    loss = SoftplusLoss(),
    batch_size = train_dataloader.get_batch_size(),
    regul_rate = 1.0
)

# train the model
trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 2000, alpha = 0.5, use_gpu = True, opt_method = "adagrad")
trainer.run()

# Writing into a directory that does not exist fails, and that would only
# surface after the full training run; create the checkpoint directory
# first (no-op if it is already there).
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok = True)
complEx.save_checkpoint(CHECKPOINT_PATH)

# test the model: reload from disk so evaluation uses exactly the
# parameters that were saved (also exercises the checkpoint round-trip).
complEx.load_checkpoint(CHECKPOINT_PATH)
tester = Tester(model = complEx, data_loader = test_dataloader, use_gpu = True)
tester.run_link_prediction(type_constrain = False)