Skip to content

Commit

Permalink
Add PyTorch version
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 committed Dec 30, 2019
1 parent 8954d51 commit 3a71c6f
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 6 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ Accepted to KDD 2019 Research Track!

## Prerequisites

- Linux or macOS
- Python 3
- TensorFlow GPU >= 1.8
- NVIDIA GPU + CUDA cuDNN
- TensorFlow >= 1.8 (or PyTorch)

## Getting Started

Expand Down Expand Up @@ -45,7 +43,7 @@ These datasets are sampled from the original datasets.

#### Training on the existing datasets

You can use `./scripts/run_example.sh` or `python src/main.py --input data/example` to train GATNE-T model on the example data. (If you share the server with others or you want to use the specific GPU(s), you may need to set `CUDA_VISIBLE_DEVICES`.)
You can use `./scripts/run_example.sh`, `python src/main.py --input data/example`, or `python src/main_pytorch.py --input data/example` to train the GATNE-T model on the example data. (If you share the server with others or want to use specific GPU(s), you may need to set `CUDA_VISIBLE_DEVICES`.)

If you want to train on the Amazon dataset, you can run `python src/main.py --input data/amazon` or `python src/main.py --input data/amazon --features data/amazon/feature.txt` to train the GATNE-T model or the GATNE-I model, respectively.

Expand Down
2 changes: 0 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ def get_batches(pairs, neighbors, batch_size):

def train_model(network_data, feature_dic, log_name):
all_walks = generate_walks(network_data, args.num_walks, args.walk_length, args.schema, file_name)

vocab, index2word = generate_vocab(all_walks)

train_pairs = generate_pairs(all_walks, vocab, args.window_size)

edge_types = list(network_data.keys())
Expand Down
295 changes: 295 additions & 0 deletions src/main_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
import math
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from numpy import random
from torch.nn.parameter import Parameter

from utils import *


def get_batches(pairs, neighbors, batch_size):
    """Yield consecutive mini-batches of training pairs as tensors.

    Each entry of ``pairs`` is read positionally as
    ``(source, target, edge_type)``; ``neighbors`` is indexed by the source
    node to fetch that node's neighbor lists. The final batch may hold fewer
    than ``batch_size`` items, and nothing is yielded when ``pairs`` is empty.

    Yields:
        A 4-tuple of tensors ``(sources, targets, edge_types, neighbor_lists)``.
    """
    total = len(pairs)
    n_batches = (total + (batch_size - 1)) // batch_size

    for batch_no in range(n_batches):
        start = batch_no * batch_size
        stop = min(start + batch_size, total)
        chunk = [pairs[k] for k in range(start, stop)]

        sources = [item[0] for item in chunk]
        targets = [item[1] for item in chunk]
        edge_kinds = [item[2] for item in chunk]
        neighbor_lists = [neighbors[item[0]] for item in chunk]

        yield (
            torch.tensor(sources),
            torch.tensor(targets),
            torch.tensor(edge_kinds),
            torch.tensor(neighbor_lists),
        )


class GATNEModel(nn.Module):
    """Node embedding model combining a base embedding with attended,
    edge-type-specific neighbor embeddings.

    Parameters held by the module:
        node_embeddings:      (num_nodes, embedding_size)
        node_type_embeddings: (num_nodes, edge_type_count, embedding_u_size)
        trans_weights:        (edge_type_count, embedding_u_size, embedding_size)
        trans_weights_s1:     (edge_type_count, embedding_u_size, dim_a)
        trans_weights_s2:     (edge_type_count, dim_a, 1)
    """

    def __init__(
        self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
    ):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        # Storage is allocated uninitialised; reset_parameters() fills it.
        base_shape = (num_nodes, embedding_size)
        typed_shape = (num_nodes, edge_type_count, embedding_u_size)
        trans_shape = (edge_type_count, embedding_u_size, embedding_size)
        s1_shape = (edge_type_count, embedding_u_size, dim_a)
        s2_shape = (edge_type_count, dim_a, 1)

        self.node_embeddings = Parameter(torch.FloatTensor(*base_shape))
        self.node_type_embeddings = Parameter(torch.FloatTensor(*typed_shape))
        self.trans_weights = Parameter(torch.FloatTensor(*trans_shape))
        self.trans_weights_s1 = Parameter(torch.FloatTensor(*s1_shape))
        self.trans_weights_s2 = Parameter(torch.FloatTensor(*s2_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise embeddings uniformly in [-1, 1] and the transformation
        weights from a zero-mean normal with std ``1 / sqrt(embedding_size)``."""
        for table in (self.node_embeddings, self.node_type_embeddings):
            table.data.uniform_(-1.0, 1.0)

        std = 1.0 / math.sqrt(self.embedding_size)
        for weight in (
            self.trans_weights,
            self.trans_weights_s1,
            self.trans_weights_s2,
        ):
            weight.data.normal_(std=std)

    def forward(self, train_inputs, train_types, node_neigh):
        """Return L2-normalised embeddings for a batch of (node, edge type).

        Args:
            train_inputs: index tensor of node ids.
            train_types: index tensor of edge-type ids, one per node.
            node_neigh: index tensor of neighbor node ids; must be 3-D
                (assumed (batch, edge_type, neighbor) -- confirm with caller).
        """
        base_embed = self.node_embeddings[train_inputs]

        # Lookup appends (edge_type, u_size) axes to node_neigh's three axes.
        neighbor_embed = self.node_type_embeddings[node_neigh]

        # For edge type r keep only neighbors listed under r, embedded with
        # the type-r table, then stack the per-type slices on a new axis 1.
        per_type = [
            neighbor_embed[:, r, :, r, :] for r in range(self.edge_type_count)
        ]
        stacked = torch.stack(per_type, dim=1)
        node_type_embed = stacked.sum(dim=2)

        # Select the transformation weights belonging to each sample's type.
        w_out = self.trans_weights[train_types]
        w_s1 = self.trans_weights_s1[train_types]
        w_s2 = self.trans_weights_s2[train_types]

        # Attention over edge types: softmax(tanh(U @ W1) @ W2).
        hidden = torch.tanh(torch.matmul(node_type_embed, w_s1))
        scores = torch.matmul(hidden, w_s2).squeeze(2)
        attention = F.softmax(scores, dim=1).unsqueeze(1)

        pooled = torch.matmul(attention, node_type_embed)
        combined = base_embed + torch.matmul(pooled, w_out).squeeze(1)

        return F.normalize(combined, dim=1)


class NSLoss(nn.Module):
    """Negative-sampling loss over a learnable table of output (context) vectors.

    For each sample the loss is
    ``-(log sigmoid(e . w_label) + sum_k log sigmoid(-e . w_neg_k))``,
    averaged over the batch, with ``num_sampled`` negatives drawn per sample.
    """

    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Sampling weight for index k is
        # (log(k + 2) - log(k + 1)) / log(num_nodes + 1), which decreases
        # with k, so low indices are drawn as negatives more often.
        # F.normalize only rescales the vector (L2); torch.multinomial does
        # not need the weights to sum to one.
        # NOTE: this is a plain tensor attribute, not a registered buffer,
        # so Module.to(device) does not move it.
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the output vectors from N(0, 1 / embedding_size)."""
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        """Return the mean negative-sampling loss for a batch.

        Args:
            input: tensor whose first dimension is the batch size; only its
                shape is used.
            embs: (batch, embedding_size) embeddings of the input nodes.
            label: index tensor of positive (context) node ids, one per row.

        Returns:
            A scalar tensor: the summed per-sample loss divided by batch size.
        """
        n = input.shape[0]

        # logsigmoid is the numerically stable form of log(sigmoid(x)); the
        # naive composition underflows to log(0) = -inf for very negative x.
        log_target = F.logsigmoid(
            torch.sum(torch.mul(embs, self.weights[label]), 1)
        )

        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        # Negating the sampled vectors turns sigmoid(x) into sigmoid(-x).
        noise = torch.neg(self.weights[negs])
        # bmm: (n, num_sampled, dim) x (n, dim, 1) -> (n, num_sampled, 1);
        # summing over the samples leaves (n, 1), squeezed to (n,).
        sum_log_sampled = torch.sum(
            F.logsigmoid(torch.bmm(noise, embs.unsqueeze(2))), 1
        ).squeeze(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n


def _build_neighbor_table(network_data, edge_types, vocab, num_nodes, neighbor_samples):
    """Return ``neighbors[node][edge_type]`` lists padded or sampled towards ``neighbor_samples`` entries."""
    edge_type_count = len(edge_types)
    neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
    for r, edge_type in enumerate(edge_types):
        for (x, y) in network_data[edge_type]:
            ix = vocab[x].index
            iy = vocab[y].index
            # Every edge is recorded in both directions.
            neighbors[ix][r].append(iy)
            neighbors[iy][r].append(ix)
        for i in range(num_nodes):
            known = len(neighbors[i][r])
            if known == 0:
                # No neighbor under this edge type: the node stands in for itself.
                neighbors[i][r] = [i] * neighbor_samples
            elif known < neighbor_samples:
                # Too few: pad by resampling (with replacement) from what exists.
                neighbors[i][r].extend(
                    list(
                        np.random.choice(
                            neighbors[i][r], size=neighbor_samples - known
                        )
                    )
                )
            elif known > neighbor_samples:
                # Too many: keep a random sample (np.random.choice default,
                # i.e. with replacement).
                neighbors[i][r] = list(
                    np.random.choice(neighbors[i][r], size=neighbor_samples)
                )
    return neighbors


def _export_embeddings(model, neighbors, edge_types, index2word, device):
    """Return ``{edge_type: {node_name: numpy_vector}}`` for every node under every edge type."""
    edge_type_count = len(edge_types)
    final_model = {edge_type: {} for edge_type in edge_types}
    train_types = torch.tensor(list(range(edge_type_count))).to(device)
    # Inference only: skip autograd bookkeeping for these forward passes.
    with torch.no_grad():
        for i in range(len(index2word)):
            # One row per edge type, all for the same node i.
            train_inputs = torch.tensor(
                [i for _ in range(edge_type_count)]
            ).to(device)
            node_neigh = torch.tensor(
                [neighbors[i] for _ in range(edge_type_count)]
            ).to(device)
            node_emb = model(train_inputs, train_types, node_neigh)
            for j in range(edge_type_count):
                final_model[edge_types[j]][index2word[i]] = (
                    node_emb[j].cpu().numpy()
                )
    return final_model


def train_model(network_data):
    """Train GATNE on ``network_data`` and evaluate link prediction each epoch.

    Reads the module-level globals ``args``, ``file_name``,
    ``valid_true_data_by_edge``, ``valid_false_data_by_edge``,
    ``testing_true_data_by_edge`` and ``testing_false_data_by_edge``, and
    calls ``generate_walks``, ``generate_vocab``, ``generate_pairs`` and
    ``evaluate``, none of which are defined in this file (presumably they
    come from ``utils``).

    Args:
        network_data: mapping from edge type to an iterable of ``(x, y)``
            node pairs.

    Returns:
        ``(average_auc, average_f1, average_pr)``: test metrics averaged over
        the evaluated edge types for the last epoch that ran. All three are
        NaN if no epoch ran.
    """
    all_walks = generate_walks(network_data, args.num_walks, args.walk_length, args.schema, file_name)
    vocab, index2word = generate_vocab(all_walks)
    train_pairs = generate_pairs(all_walks, vocab, args.window_size)

    edge_types = list(network_data.keys())

    num_nodes = len(index2word)
    edge_type_count = len(edge_types)
    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions
    embedding_u_size = args.edge_dim
    num_sampled = args.negative_samples
    dim_a = args.att_dim
    neighbor_samples = args.neighbor_samples

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    neighbors = _build_neighbor_table(
        network_data, edge_types, vocab, num_nodes, neighbor_samples
    )

    model = GATNEModel(
        num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
    )
    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)

    model.to(device)
    nsloss.to(device)

    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
    )

    eval_all = args.eval_type == "all"
    eval_types = args.eval_type.split(",")

    # Defined up front so the return statement is valid even if epochs == 0.
    average_auc = average_f1 = average_pr = float("nan")

    best_score = 0
    patience = 0
    for epoch in range(epochs):
        random.shuffle(train_pairs)
        batches = get_batches(train_pairs, neighbors, batch_size)

        data_iter = tqdm.tqdm(
            batches,
            desc="epoch %d" % (epoch),
            total=(len(train_pairs) + (batch_size - 1)) // batch_size,
            bar_format="{l_bar}{r_bar}",
        )
        avg_loss = 0.0

        for i, data in enumerate(data_iter):
            # Batch layout from get_batches: (inputs, labels, types, neighbors).
            inputs, labels, types, neigh = (part.to(device) for part in data)

            optimizer.zero_grad()
            embs = model(inputs, types, neigh)
            loss = nsloss(inputs, embs, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            avg_loss += loss_value

            if i % 5000 == 0:
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "loss": loss_value,
                }
                data_iter.write(str(post_fix))

        final_model = _export_embeddings(
            model, neighbors, edge_types, index2word, device
        )

        valid_aucs, valid_f1s, valid_prs = [], [], []
        test_aucs, test_f1s, test_prs = [], [], []
        for edge_type in edge_types:
            if eval_all or edge_type in eval_types:
                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_type],
                    valid_true_data_by_edge[edge_type],
                    valid_false_data_by_edge[edge_type],
                )
                valid_aucs.append(tmp_auc)
                valid_f1s.append(tmp_f1)
                valid_prs.append(tmp_pr)

                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_type],
                    testing_true_data_by_edge[edge_type],
                    testing_false_data_by_edge[edge_type],
                )
                test_aucs.append(tmp_auc)
                test_f1s.append(tmp_f1)
                test_prs.append(tmp_pr)
        print("valid auc:", np.mean(valid_aucs))
        print("valid pr:", np.mean(valid_prs))
        print("valid f1:", np.mean(valid_f1s))

        # NOTE(review): these are overwritten every epoch, so the returned
        # test metrics belong to the last epoch run, not necessarily the one
        # with the best validation AUC.
        average_auc = np.mean(test_aucs)
        average_f1 = np.mean(test_f1s)
        average_pr = np.mean(test_prs)

        # Early stopping on mean validation AUC.
        cur_score = np.mean(valid_aucs)
        if cur_score > best_score:
            best_score = cur_score
            patience = 0
        else:
            patience += 1
            if patience > args.patience:
                print("Early Stopping")
                break
    return average_auc, average_f1, average_pr


if __name__ == "__main__":
    # `args`, `file_name` and the four *_data_by_edge dicts are assigned at
    # module level on purpose: train_model() reads them as globals.
    args = parse_args()
    file_name = args.input
    print(args)

    train_path = file_name + "/train.txt"
    valid_path = file_name + "/valid.txt"
    test_path = file_name + "/test.txt"

    training_data_by_type = load_training_data(train_path)
    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(valid_path)
    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(test_path)

    average_auc, average_f1, average_pr = train_model(training_data_by_type)

    # Report the metrics in the order ROC-AUC, PR-AUC, F1.
    for label, value in (
        ("Overall ROC-AUC:", average_auc),
        ("Overall PR-AUC", average_pr),
        ("Overall F1:", average_f1),
    ):
        print(label, value)

0 comments on commit 3a71c6f

Please sign in to comment.