GAT.py
import torch
from torch import nn
import torch.nn.functional as F

# Adapted from: https://zhuanlan.zhihu.com/p/128072201
# Graph Attention Network (GAT)
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    (a single graph attention layer)
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features    # size of each node's input feature vector
        self.out_features = out_features  # size of each node's output feature vector
        self.dropout = dropout            # dropout probability
        self.alpha = alpha                # negative slope of the LeakyReLU activation
        self.concat = concat              # if True, apply an ELU activation to the output
        self.elu = nn.ELU()
        # Trainable parameters, i.e. W and a in the paper
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)  # Xavier initialisation
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)  # Xavier initialisation
        # LeakyReLU activation used for the attention coefficients
        self.leakyrelu = nn.LeakyReLU(self.alpha)
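
    # For reference, the forward pass below follows the attention mechanism of the GAT paper
    # (https://arxiv.org/abs/1710.10903):
    #   e_ij     = LeakyReLU(a^T [W h_i || W h_j])   (unnormalised attention coefficient)
    #   alpha_ij = softmax_j(e_ij), taken over the neighbours j of node i given by adj
    #   h_i'     = ELU( sum_j alpha_ij * W h_j )     (the ELU is applied only when concat=True)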
    def forward(self, inp, adj=None):
        """
        inp: input features [N, in_features], where in_features is the size of each node's feature vector
        adj: adjacency matrix of the graph [N, N], entries are 0 or 1; if None, the masking/softmax step
             is skipped and the raw attention scores are only regularised with dropout (dense variant)
        """
        h = torch.mm(inp, self.W)  # [N, out_features]
        N = h.size()[0]            # N: number of nodes in the graph
        # All pairwise concatenations [W h_i || W h_j]: [N, N, 2*out_features]
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, N, 2 * self.out_features)
        # Unnormalised attention coefficients: [N, N, 1] => [N, N]
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        if adj is not None:
            # Where the adjacency matrix is positive the two nodes are connected and the coefficient is kept;
            # otherwise it is masked with a very large negative value so that softmax effectively ignores it.
            zero_vec = -1e12 * torch.ones_like(e)
            attention = torch.where(adj > 0, e, zero_vec)  # [N, N]
            attention = F.softmax(attention, dim=1)        # normalised attention weights, still [N, N]
            attention = F.dropout(attention, self.dropout, training=self.training)  # dropout against overfitting
        else:
            # Dense variant without an adjacency matrix; the dropout here could also be replaced by other methods
            attention = F.dropout(e, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)  # [N, N] . [N, out_features] => [N, out_features]
        # Node representations updated from their neighbours via the attention weights
        if self.concat:
            return self.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

# Multi-head graph attention network
class GAT(nn.Module):
    def __init__(self, n_feat, n_hid, n_class, dropout, alpha, n_heads):
        """Dense version of GAT.
        n_heads is the number of attention heads (GraphAttentionLayer instances); their outputs are
        concatenated, similar to multi-head self-attention, so that features are extracted from
        different subspaces.
        """
        super(GAT, self).__init__()
        self.dropout = dropout
        self.elu = nn.ELU()
        # Define the multi-head graph attention layers
        self.attentions = [GraphAttentionLayer(n_feat, n_hid, dropout=dropout, alpha=alpha, concat=True) for _ in range(n_heads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)  # register each head as a PyTorch submodule
        # Output layer, also a graph attention layer; can be used for classification, prediction, etc.
        self.out_att = GraphAttentionLayer(n_hid * n_heads, n_class, dropout=dropout, alpha=alpha, concat=False)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, adj=None):
        x = F.dropout(x, self.dropout, training=self.training)  # dropout against overfitting
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)  # concatenate the per-head representations
        x = F.dropout(x, self.dropout, training=self.training)  # dropout against overfitting
        x = self.elu(self.out_att(x, adj))  # output layer followed by ELU activation
        return self.log_softmax(x)  # log-softmax for speed and numerical stability
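

# A minimal usage sketch (not part of the original file): it builds a small random graph and runs a
# forward pass through the GAT defined above. All hyperparameters and tensor sizes below are
# illustrative assumptions, not values prescribed by the original code.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_nodes, n_feat, n_hid, n_class, n_heads = 6, 10, 8, 3, 4  # assumed toy sizes
    x = torch.rand(num_nodes, n_feat)                        # random node features [N, n_feat]
    adj = (torch.rand(num_nodes, num_nodes) > 0.5).float()   # random 0/1 adjacency matrix [N, N]
    adj = ((adj + adj.t()) > 0).float()                      # make it symmetric (undirected graph)
    adj.fill_diagonal_(1)                                    # add self-loops so each node attends to itself
    model = GAT(n_feat=n_feat, n_hid=n_hid, n_class=n_class, dropout=0.5, alpha=0.2, n_heads=n_heads)
    out = model(x, adj)    # per-node log-probabilities, shape [N, n_class]
    out_dense = model(x)   # dense variant without an adjacency matrix
    print(out.shape, out_dense.shape)  # expected: torch.Size([6, 3]) torch.Size([6, 3])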