-
Notifications
You must be signed in to change notification settings - Fork 143
/
Copy pathgraph_layer.py
124 lines (87 loc) · 4.02 KB
/
graph_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from torch.nn import Parameter, Linear, Sequential, BatchNorm1d, ReLU
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
import time
import math
class GraphLayer(MessagePassing):
    """Multi-head graph attention layer whose scores also use node embeddings.

    Node features are linearly projected (no bias) into ``heads`` heads of
    ``out_channels`` features each. For every edge ``j -> i`` the raw score per
    head is::

        <[x_i || e_i], [att_i || att_em_i]> + <[x_j || e_j], [att_j || att_em_j]>

    It is passed through LeakyReLU and softmax-normalised over the edges that
    share the same target node ``i``. Each message is the projected source
    feature ``x_j`` scaled by its weight, and messages are summed per target
    (``aggr='add'``).

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output feature vector per head.
        heads: number of attention heads.
        concat: if True the head outputs are concatenated
            (``heads * out_channels`` features); otherwise they are averaged
            (``out_channels`` features).
        negative_slope: LeakyReLU negative slope applied to the raw scores.
        dropout: dropout probability applied to the normalised attention
            weights while ``self.training`` is True.
        bias: if True, a learnable bias is added to the output.
        inter_dim: accepted for backward compatibility; not used by this layer.
        **kwargs: forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, inter_dim=-1, **kwargs):
        super().__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        # Attention weights from the last forward pass, kept only when
        # return_attention_weights=True; cleared once returned.
        self.__alpha__ = None

        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        # Per-head attention vectors for the projected features of the
        # target (i) and source (j) nodes ...
        self.att_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_j = Parameter(torch.Tensor(1, heads, out_channels))
        # ... and for their node embeddings.
        self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all learnable parameters.

        The projection and feature attention vectors use Glorot init. The
        embedding attention vectors and the bias start at zero, so the initial
        scores depend on the projected features only.
        """
        glorot(self.lin.weight)
        glorot(self.att_i)
        glorot(self.att_j)
        zeros(self.att_em_i)
        zeros(self.att_em_j)
        zeros(self.bias)

    def forward(self, x, edge_index, embedding, return_attention_weights=False):
        """Apply the attention layer to a graph.

        Args:
            x: node features. Either a single tensor used for both source and
                target nodes, or a ``(source, target)`` pair of tensors.
            edge_index: edge index whose row 0 holds source and row 1 target
                node indices. Existing self-loops are removed, then one
                self-loop is added for every target node.
            embedding: per-node embedding indexed by node id. Its last
                dimension must equal ``out_channels`` so it can be scored
                against ``att_em_i`` / ``att_em_j``. Pass ``None`` to score
                on the projected features only.
            return_attention_weights: if True, also return the edge index that
                was used and the attention weights (taken before dropout).

        Returns:
            ``out``, or ``(out, (edge_index, alpha))`` when
            ``return_attention_weights`` is True. ``out`` has
            ``heads * out_channels`` columns if ``concat`` else
            ``out_channels``.
        """
        if torch.is_tensor(x):
            x = self.lin(x)
            x = (x, x)
        else:
            x = (self.lin(x[0]), self.lin(x[1]))

        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index,
                                       num_nodes=x[1].size(self.node_dim))

        out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index,
                             return_attention_weights=return_attention_weights)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if return_attention_weights:
            # Hand the cached weights to the caller and clear the cache.
            alpha, self.__alpha__ = self.__alpha__, None
            return out, (edge_index, alpha)
        return out

    def message(self, x_i, x_j, edge_index_i, size_i,
                embedding,
                edges,
                return_attention_weights):
        """Compute the attention-weighted message for every edge ``j -> i``.

        ``x_i`` / ``x_j`` are the projected target / source features gathered
        per edge by ``propagate``. ``embedding`` and ``edges`` arrive ungathered,
        so the embedding rows are looked up here: targets via ``edge_index_i``
        and sources via ``edges[0]``.
        """
        x_i = x_i.view(-1, self.heads, self.out_channels)
        x_j = x_j.view(-1, self.heads, self.out_channels)

        if embedding is not None:
            embedding_i, embedding_j = embedding[edge_index_i], embedding[edges[0]]
            # Share the same embedding across all heads.
            embedding_i = embedding_i.unsqueeze(1).repeat(1, self.heads, 1)
            embedding_j = embedding_j.unsqueeze(1).repeat(1, self.heads, 1)

            key_i = torch.cat((x_i, embedding_i), dim=-1)
            key_j = torch.cat((x_j, embedding_j), dim=-1)

            cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1)
            cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1)
        else:
            # No embedding: score on the projected features alone. This equals
            # the embedding branch with an all-zero embedding, and it keeps
            # key_*/cat_att_* bound so the lines below do not fail.
            key_i, key_j = x_i, x_j
            cat_att_i, cat_att_j = self.att_i, self.att_j

        alpha = (key_i * cat_att_i).sum(-1) + (key_j * cat_att_j).sum(-1)
        alpha = alpha.view(-1, self.heads, 1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # num_nodes must be passed by keyword: in newer torch_geometric the
        # third positional parameter of utils.softmax is ``ptr``.
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)

        if return_attention_weights:
            self.__alpha__ = alpha

        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.view(-1, self.heads, 1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)