Skip to content

Commit

Permalink
Merge branch 'master' of github.com:THUDM/GATNE
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 committed Nov 6, 2019
2 parents 9634761 + 52b4ecd commit 401251a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ If you have ANY difficulties to get things working in the above steps, feel free
Please cite our paper if you find this code useful for your research:

```
@article{cen2019representation,
title={Representation Learning for Attributed Multiplex Heterogeneous Network},
author={Cen, Yukuo and Zou, Xu and Zhang, Jianwei and Yang, Hongxia and Zhou, Jingren and Tang, Jie},
journal={arXiv preprint arXiv:1905.01669},
year={2019}
@inproceedings{cen2019representation,
title = {Representation Learning for Attributed Multiplex Heterogeneous Network},
author = {Cen, Yukuo and Zou, Xu and Zhang, Jianwei and Yang, Hongxia and Zhou, Jingren and Tang, Jie},
booktitle = {Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining},
year = {2019},
pages = {1358--1368},
publisher = {ACM},
}
```
20 changes: 14 additions & 6 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,29 @@ def get_score(local_model, node1, node2):
vector2 = local_model[node2]
return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
except Exception as e:
print(e)
pass


def evaluate(model, true_edges, false_edges):
true_list = list()
prediction_list = list()
true_num = 0
for edge in true_edges:
tmp_score = get_score(model, str(edge[0]), str(edge[1]))
true_list.append(1)
prediction_list.append(tmp_score)
if tmp_score is not None:
true_list.append(1)
prediction_list.append(tmp_score)
true_num += 1

for edge in false_edges:
tmp_score = get_score(model, str(edge[0]), str(edge[1]))
true_list.append(0)
prediction_list.append(tmp_score)
if tmp_score is not None:
true_list.append(0)
prediction_list.append(tmp_score)

sorted_pred = prediction_list[:]
sorted_pred.sort()
threshold = sorted_pred[-len(true_edges)]
threshold = sorted_pred[-true_num]

y_pred = np.zeros(len(prediction_list), dtype=np.int32)
for i in range(len(prediction_list)):
Expand Down Expand Up @@ -193,6 +197,10 @@ def train_model(network_data, feature_dic, log_name):
train_pairs = generate_pairs(all_walks, vocab)

edge_types = list(network_data.keys())
if edge_types[-1] != 'Base':
edge_types.sort()
edge_types.remove('Base')
edge_types.append('Base')

num_nodes = len(index2word)
edge_type_count = len(edge_types) - 1
Expand Down

0 comments on commit 401251a

Please sign in to comment.