-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4382efe
commit 466e41a
Showing
6 changed files
with
84 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
|
||
import torch | ||
import pytorch_lightning as pl | ||
import yaml | ||
from dataloader.link_pre_dataloader import LinkPredictionDataloader | ||
from models.LinkPreTask import LinkPredictionTask | ||
# 用来在晚上连续跑实验的工具 | ||
def get_trainer_model_dataloader_from_dir(settings): | ||
dl = LinkPredictionDataloader(**settings['data']) | ||
model = LinkPredictionTask(dl.edge_index, dl.edge_type, dl.feature_data, dl.N, **settings['model']) | ||
checkpoint_callback = pl.callbacks.ModelCheckpoint(**settings['callback']) | ||
trainer = pl.Trainer(callbacks=[checkpoint_callback], **settings['train']) | ||
return trainer, model, dl | ||
|
||
def plan(base_settings,model_replace_key,model_replace_values): | ||
''' | ||
:param base_settings: 基础配置 | ||
:param model_replace_key: 取代的超参 | ||
:param model_replace_values: 超参值的列表 | ||
:return: | ||
''' | ||
for v in model_replace_values: | ||
base_settings['model'][model_replace_key] = v | ||
print('--------------------------------------------------') | ||
print(model_replace_key, '=', v, 'has bean done!') | ||
trainer,model,dl=get_trainer_model_dataloader_from_dir(base_settings) | ||
trainer.fit(model,dl) | ||
# 测试 | ||
# 加载参数 | ||
ckpt_path = trainer.log_dir + '/checkpoints/' + os.listdir(trainer.log_dir + '/checkpoints')[0] | ||
state_dict = torch.load(ckpt_path)['state_dict'] | ||
model.load_state_dict(state_dict) | ||
trainer.test(model, dl.test_dataloader()) | ||
print(model_replace_key, '=', v, 'has finished! result in',trainer.log_dir) | ||
print('--------------------------------------------------') | ||
del trainer | ||
del model | ||
del dl | ||
print('finish plan!') | ||
|
||
if __name__ == '__main__': | ||
yaml_path = 'settings/yot_settings.yaml' | ||
key = 'L' | ||
values = [1,2,3,4,5,6] | ||
# key = 'lam' | ||
# values = [1.,0.5,0.3,0.05,0.01,0.001] | ||
with open(yaml_path) as f: | ||
settings = dict(yaml.load(f,yaml.FullLoader)) | ||
plan(settings,key,values) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters