Skip to content

Commit

Permalink
Merge branch 'master' of github.com:THUDM/GATNE
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 committed Dec 16, 2019
2 parents 19d2194 + 4a44886 commit d499c37
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
45 changes: 24 additions & 21 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,23 @@ def generate_vocab(all_walks):

return vocab, index2word

def get_batches(pairs, neighbors, batch_size):
    """Yield mini-batches of training pairs together with each head node's neighbors.

    Args:
        pairs: list of (node, context_node, edge_type) index triples.
        neighbors: per-node neighbor indices; indexed by the head node
            ``pairs[index][0]`` (assumed shape per node is fixed, e.g.
            (edge_type_count, neighbor_samples) — confirm against caller).
        batch_size: maximum number of pairs per yielded batch.

    Yields:
        Tuples of four ``np.int32`` arrays: head nodes ``x``, context nodes
        ``y`` reshaped to (-1, 1) for the NCE loss, edge types ``t``, and
        the neighbor indices ``neigh`` of each head node.
    """
    # Ceiling division: the final batch may hold fewer than batch_size pairs.
    n_batches = (len(pairs) + (batch_size - 1)) // batch_size

    for idx in range(n_batches):
        x, y, t, neigh = [], [], [], []
        for i in range(batch_size):
            index = idx * batch_size + i
            if index >= len(pairs):  # ran off the end: short last batch
                break
            x.append(pairs[index][0])
            y.append(pairs[index][1])
            t.append(pairs[index][2])
            neigh.append(neighbors[pairs[index][0]])
        # Yield lazily instead of accumulating every batch in memory.
        yield (np.array(x).astype(np.int32),
               np.array(y).reshape(-1, 1).astype(np.int32),
               np.array(t).astype(np.int32),
               np.array(neigh).astype(np.int32))

def generate_walks(network_data):
if args.schema is not None:
Expand All @@ -177,7 +179,7 @@ def generate_walks(network_data):

all_walks.append(layer_walks)

print('finish generating the walks')
print('Finish generating the walks')

return all_walks

Expand Down Expand Up @@ -230,8 +232,8 @@ def train_model(network_data, feature_dic, log_name):
print('feature dimension: ' + str(feature_dim))
features = np.zeros((num_nodes, feature_dim), dtype=np.float32)
for key, value in feature_dic.items():
if key in index2word:
features[index2word.index(key), :] = np.array(value)
if key in vocab:
features[vocab[key].index, :] = np.array(value)

with graph.as_default():
global_step = tf.Variable(0, name='global_step', trainable=False)
Expand All @@ -253,12 +255,13 @@ def train_model(network_data, feature_dic, log_name):
nce_weights = tf.Variable(tf.truncated_normal([num_nodes, embedding_size], stddev=1.0 / math.sqrt(embedding_size)))
nce_biases = tf.Variable(tf.zeros([num_nodes]))

node_neighbors = tf.Variable(neighbors, trainable=False)
# node_neighbors = tf.Variable(neighbors, trainable=False)

# Input data
train_inputs = tf.placeholder(tf.int32, shape=[None])
train_labels = tf.placeholder(tf.int32, shape=[None, 1])
train_types = tf.placeholder(tf.int32, shape=[None])
node_neigh = tf.placeholder(tf.int32, shape=[None, edge_type_count, neighbor_samples])

# Look up embeddings for nodes
if feature_dic is not None:
Expand All @@ -267,7 +270,7 @@ def train_model(network_data, feature_dic, log_name):
else:
node_embed = tf.nn.embedding_lookup(node_embeddings, train_inputs)

node_neigh = tf.nn.embedding_lookup(node_neighbors, train_inputs)
# node_neigh = tf.nn.embedding_lookup(node_neighbors, train_inputs)
if feature_dic is not None:
node_embed_neighbors = tf.nn.embedding_lookup(node_features, node_neigh)
node_embed_tmp = tf.concat([tf.matmul(tf.reshape(tf.slice(node_embed_neighbors, [0, i, 0, 0], [-1, 1, -1, -1]), [-1, feature_dim]), tf.reshape(tf.slice(u_embed_trans, [i, 0, 0], [1, -1, -1]), [feature_dim, embedding_u_size])) for i in range(edge_type_count)], axis=0)
Expand Down Expand Up @@ -320,25 +323,25 @@ def train_model(network_data, feature_dic, log_name):
sess.run(init)

print('Training')
iter = 0
g_iter = 0
best_score = 0
patience = 0
for epoch in range(epochs):
random.shuffle(train_pairs)
batches = get_batches(train_pairs, batch_size)
batches = get_batches(train_pairs, neighbors, batch_size)

data_iter = tqdm.tqdm(enumerate(batches),
desc="EP:%d" % (epoch),
total=len(batches),
data_iter = tqdm.tqdm(batches,
desc="epoch %d" % (epoch),
total=(len(train_pairs) + (batch_size - 1)) // batch_size,
bar_format="{l_bar}{r_bar}")
avg_loss = 0.0

for i, data in data_iter:
feed_dict = {train_inputs: data[0], train_labels: data[1], train_types: data[2]}
for i, data in enumerate(data_iter):
feed_dict = {train_inputs: data[0], train_labels: data[1], train_types: data[2], node_neigh: data[3]}
_, loss_value, summary_str = sess.run([optimizer, loss, merged], feed_dict)
writer.add_summary(summary_str, iter)
writer.add_summary(summary_str, g_iter)

iter += 1
g_iter += 1

avg_loss += loss_value

Expand All @@ -354,7 +357,7 @@ def train_model(network_data, feature_dic, log_name):
final_model = dict(zip(edge_types[:-1], [dict() for _ in range(edge_type_count)]))
for i in range(edge_type_count):
for j in range(num_nodes):
final_model[edge_types[i]][index2word[j]] = np.array(sess.run(last_node_embed, {train_inputs: [j], train_types: [i]})[0])
final_model[edge_types[i]][index2word[j]] = np.array(sess.run(last_node_embed, {train_inputs: [j], train_types: [i], node_neigh: [neighbors[j]]})[0])
valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], []
for i in range(edge_type_count):
Expand All @@ -369,7 +372,7 @@ def train_model(network_data, feature_dic, log_name):
test_f1s.append(tmp_f1)
test_prs.append(tmp_pr)
print('valid auc:', np.mean(valid_aucs))
print('valid pr', np.mean(valid_prs))
print('valid pr:', np.mean(valid_prs))
print('valid f1:', np.mean(valid_f1s))

average_auc = np.mean(test_aucs)
Expand Down
8 changes: 3 additions & 5 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_G_from_edges(edges):
return tmp_G

def load_training_data(f_name):
print('We are loading training data from:', f_name)
print('We are loading data from:', f_name)
edge_data_by_type = dict()
all_edges = list()
all_nodes = list()
Expand All @@ -36,13 +36,12 @@ def load_training_data(f_name):
all_nodes = list(set(all_nodes))
all_edges = list(set(all_edges))
edge_data_by_type['Base'] = all_edges
print('total training nodes: ' + str(len(all_nodes)))
print('Finish loading training data')
print('Total training nodes: ' + str(len(all_nodes)))
return edge_data_by_type


def load_testing_data(f_name):
print('We are loading testing data from:', f_name)
print('We are loading data from:', f_name)
true_edge_data_by_type = dict()
false_edge_data_by_type = dict()
all_edges = list()
Expand All @@ -62,7 +61,6 @@ def load_testing_data(f_name):
all_nodes.append(x)
all_nodes.append(y)
all_nodes = list(set(all_nodes))
print('Finish loading testing data')
return true_edge_data_by_type, false_edge_data_by_type

def load_node_type(f_name):
Expand Down

0 comments on commit d499c37

Please sign in to comment.