# graph_model.py
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.utils import expand_as_pair
class WeightedGATConv(nn.Module):
r"""Graph attention layer from `Graph Attention Network
<https://arxiv.org/pdf/1710.10903.pdf>`__
.. math::
h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}
    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
node :math:`j`:
.. math::
\alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})
e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)
Parameters
----------
in_feats : int, or pair of ints
Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
GATConv can be applied on homogeneous graph and unidirectional
`bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int
Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
num_heads : int
Number of heads in Multi-Head Attention.
feat_drop : float, optional
Dropout rate on feature. Defaults: ``0``.
attn_drop : float, optional
Dropout rate on attention weight. Defaults: ``0``.
negative_slope : float, optional
LeakyReLU angle of negative slope. Defaults: ``0.2``.
residual : bool, optional
If True, use residual connection. Defaults: ``False``.
activation : callable activation function/layer or None, optional.
If not None, applies an activation function to the updated node features.
Default: ``None``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
bias : bool, optional
If True, learns a bias term. Defaults: ``True``.
Note
----
    Zero in-degree nodes will lead to invalid output values. This is because no message
    will be passed to those nodes, so the aggregation function will be applied to an empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graphs,
    since the edge type can not be decided for self-loop edges. Set ``allow_zero_in_degree``
    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
    A common practice to handle this is to filter out the zero-in-degree nodes after the
    convolution.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import GATConv
>>> # Case 1: Homogeneous graph
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> gatconv = GATConv(10, 2, num_heads=3)
>>> res = gatconv(g, feat)
>>> res
tensor([[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
>>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
>>> gatconv = GATConv((5,10), 2, 3)
>>> res = gatconv(g, (u_feat, v_feat))
>>> res
tensor([[[-0.6066, 1.0268],
[-0.5945, -0.4801],
[ 0.1594, 0.3825]],
[[ 0.0268, 1.0783],
[ 0.5041, -1.3025],
[ 0.6568, 0.7048]],
[[-0.2688, 1.0543],
[-0.0315, -0.9016],
[ 0.3943, 0.5347]],
[[-0.6066, 1.0268],
[-0.5945, -0.4801],
[ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True):
super(WeightedGATConv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=False)
else:
self.fc = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,)))
else:
self.register_buffer('bias', None)
if residual:
if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.reset_parameters()
self.activation = activation
self.weight_linear = nn.Linear(1, 1)
def reset_parameters(self):
"""
Description
-----------
Reinitialize learnable parameters.
Note
----
        The fc weights :math:`W^{(l)}` are initialized using Xavier normal (Glorot) initialization.
        The attention weights are also initialized with the Xavier normal initialization method.
"""
gain = nn.init.calculate_gain('relu')
if hasattr(self, 'fc'):
nn.init.xavier_normal_(self.fc.weight, gain=gain)
else:
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
if self.bias is not None:
nn.init.constant_(self.bias, 0)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, get_attention=False):
r"""
Description
-----------
Compute graph attention network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
get_attention : bool, optional
            Whether to return the attention values. Defaults to ``False``.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
torch.Tensor, optional
The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
edges. This is returned only when :attr:`get_attention` is ``True``.
Raises
------
DGLError
If there are 0-in-degree nodes in the input graph, it will raise DGLError
since no message will be passed to those nodes. This will cause invalid output.
The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if isinstance(feat, tuple):
src_prefix_shape = feat[0].shape[:-1]
dst_prefix_shape = feat[1].shape[:-1]
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
feat_src = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
feat_dst = self.fc(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats)
else:
feat_src = self.fc_src(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(
*dst_prefix_shape, self._num_heads, self._out_feats)
else:
src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
*src_prefix_shape, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # the addition can be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and reduces the memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
            # fuse the edge weights: pass the scalar weight stored in
            # graph.edata['weight'] through a learned linear map and use it to
            # scale the unnormalized attention scores
            _w = self.weight_linear(graph.edata['weight'].unsqueeze(1))
            e = e * _w.unsqueeze(2)  # (num_edges, num_heads, 1) * (num_edges, 1, 1)
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
# residual
if self.res_fc is not None:
# Use -1 rather than self._num_heads to handle broadcasting
resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
rst = rst + resval
# bias
if self.bias is not None:
rst = rst + self.bias.view(
*((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
else:
return rst
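

# A minimal usage sketch for WeightedGATConv. The toy graph, feature sizes and the
# random edge weights below are illustrative assumptions only; the layer expects a
# scalar edge feature stored under ``graph.edata['weight']``, which is passed through
# ``weight_linear`` and multiplied into the unnormalized attention scores.
def _demo_weighted_gatconv():
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    g = dgl.add_self_loop(g)                       # avoid zero-in-degree nodes
    g.edata['weight'] = th.rand(g.num_edges())     # one scalar weight per edge
    feat = th.ones(6, 10)
    conv = WeightedGATConv(10, 2, num_heads=3)
    out, attn = conv(g, feat, get_attention=True)  # out: (6, 3, 2), attn: (E, 3, 1)
    return out, attn
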
class GATLayer(torch.nn.Module):
"""
Base class for all implementations as there is much code that would otherwise be copy/pasted.
"""
head_dim = 1
def __init__(self, num_in_features, num_out_features, num_of_heads, device, concat=True, activation=nn.ELU(),
dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):
super().__init__()
# Saving these as we'll need them in forward propagation in children layers (imp1/2/3)
self.num_of_heads = num_of_heads
self.num_out_features = num_out_features
self.concat = concat # whether we should concatenate or average the attention heads
self.add_skip_connection = add_skip_connection
self.device = device
#
# Trainable weights: linear projection matrix (denoted as "W" in the paper), attention target/source
# (denoted as "a" in the paper) and bias (not mentioned in the paper but present in the official GAT repo)
#
# You can treat this one matrix as num_of_heads independent W matrices
self.linear_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
# After we concatenate target node (node i) and source node (node j) we apply the additive scoring function
# which gives us un-normalized score "e". Here we split the "a" vector - but the semantics remain the same.
# Basically instead of doing [x, y] (concatenation, x/y are node feature vectors) and dot product with "a"
# we instead do a dot product between x and "a_left" and y and "a_right" and we sum them up
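        # (a small numeric check of this equivalence is sketched right after this class definition)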
self.scoring_fn_target = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
self.scoring_fn_source = nn.Parameter(torch.Tensor(1, num_of_heads, num_out_features))
# Bias is definitely not crucial to GAT - feel free to experiment (I pinged the main author, Petar, on this one)
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(num_of_heads * num_out_features))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(num_out_features))
else:
self.register_parameter('bias', None)
if add_skip_connection:
self.skip_proj = nn.Linear(num_in_features, num_of_heads * num_out_features, bias=False)
else:
self.register_parameter('skip_proj', None)
#
# End of trainable weights
#
self.leakyReLU = nn.LeakyReLU(0.2) # using 0.2 as in the paper, no need to expose every setting
        self.softmax = nn.Softmax(dim=-1)  # -1 stands for applying the softmax along the last dimension
self.activation = activation
# Probably not the nicest design but I use the same module in 3 locations, before/after features projection
# and for attention coefficients. Functionality-wise it's the same as using independent modules.
self.dropout = nn.Dropout(p=dropout_prob)
self.log_attention_weights = log_attention_weights # whether we should log the attention weights
self.attention_weights = None # for later visualization purposes, I cache the weights here
self.init_params()
def init_params(self):
"""
        The reason we're using Glorot (a.k.a. Xavier uniform) initialization is that it's the default TF initialization:
https://stackoverflow.com/questions/37350131/what-is-the-default-variable-initializer-in-tensorflow
The original repo was developed in TensorFlow (TF) and they used the default initialization.
Feel free to experiment - there may be better initializations depending on your problem.
"""
nn.init.xavier_uniform_(self.linear_proj.weight)
nn.init.xavier_uniform_(self.scoring_fn_target)
nn.init.xavier_uniform_(self.scoring_fn_source)
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
def skip_concat_bias(self, attention_coefficients, in_nodes_features, out_nodes_features):
if self.log_attention_weights: # potentially log for later visualization in playground.py
self.attention_weights = attention_coefficients
# if the tensor is not contiguously stored in memory we'll get an error after we try to do certain ops like view
# only imp1 will enter this one
if not out_nodes_features.is_contiguous():
out_nodes_features = out_nodes_features.contiguous()
if self.add_skip_connection: # add skip or residual connection
if out_nodes_features.shape[-1] == in_nodes_features.shape[-1]: # if FIN == FOUT
# unsqueeze does this: (N, FIN) -> (N, 1, FIN), out features are (N, NH, FOUT) so 1 gets broadcast to NH
# thus we're basically copying input vectors NH times and adding to processed vectors
out_nodes_features += in_nodes_features.unsqueeze(1)
else:
# FIN != FOUT so we need to project input feature vectors into dimension that can be added to output
# feature vectors. skip_proj adds lots of additional capacity which may cause overfitting.
out_nodes_features += self.skip_proj(in_nodes_features).view(-1, self.num_of_heads,
self.num_out_features)
if self.concat:
# shape = (N, NH, FOUT) -> (N, NH*FOUT)
out_nodes_features = out_nodes_features.view(-1, self.num_of_heads * self.num_out_features)
else:
# shape = (N, NH, FOUT) -> (N, FOUT)
out_nodes_features = out_nodes_features.mean(dim=self.head_dim)
if self.bias is not None:
out_nodes_features += self.bias
return out_nodes_features if self.activation is None else self.activation(out_nodes_features)
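

# A small, self-contained numeric check (illustrative only) of the identity referenced
# in the comments above: decomposing the attention vector "a" into [a_left || a_right]
# gives a^T [W h_i || W h_j] = a_left . (W h_i) + a_right . (W h_j), which is why the
# layers score source and target features separately and then add the two scores.
def _check_attention_decomposition():
    torch.manual_seed(0)
    wh_i, wh_j = torch.rand(4), torch.rand(4)       # projected features W h_i and W h_j
    a_left, a_right = torch.rand(4), torch.rand(4)  # the two halves of the vector "a"
    lhs = torch.cat([a_left, a_right]) @ torch.cat([wh_i, wh_j])  # concatenate first, then dot with "a"
    rhs = a_left @ wh_i + a_right @ wh_j                          # score source/target separately, then add
    assert torch.allclose(lhs, rhs)
    return lhs, rhs
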
class GATLayerImp3(GATLayer):
"""
Implementation #3 was inspired by PyTorch Geometric: https://github.com/rusty1s/pytorch_geometric
But, it's hopefully much more readable! (and of similar performance)
It's suitable for both transductive and inductive settings. In the inductive setting we just merge the graphs
into a single graph with multiple components and this layer is agnostic to that fact! <3
"""
src_nodes_dim = 0 # position of source nodes in edge index
trg_nodes_dim = 1 # position of target nodes in edge index
nodes_dim = 0 # node dimension/axis
head_dim = 1 # attention head dimension/axis
def __init__(self, num_in_features, num_out_features, num_of_heads, device, concat=True, activation=nn.ELU(),
dropout_prob=0.6, add_skip_connection=True, bias=True, log_attention_weights=False):
# Delegate initialization to the base class
super().__init__(num_in_features, num_out_features, num_of_heads, device, concat, activation,
dropout_prob,
add_skip_connection, bias, log_attention_weights)
def forward(self, data):
#
# Step 1: Linear Projection + regularization
#
in_nodes_features, edge_index = data # unpack data
num_of_nodes = in_nodes_features.shape[self.nodes_dim]
assert edge_index.shape[0] == 2, f'Expected edge index with shape=(2,E) got {edge_index.shape}'
# shape = (N, FIN) where N - number of nodes in the graph, FIN - number of input features per node
# We apply the dropout to all of the input node features (as mentioned in the paper)
# Note: for Cora features are already super sparse so it's questionable how much this actually helps
in_nodes_features = self.dropout(in_nodes_features)
# shape = (N, FIN) * (FIN, NH*FOUT) -> (N, NH, FOUT) where NH - number of heads, FOUT - num of output features
# We project the input node features into NH independent output features (one for each attention head)
nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)
nodes_features_proj = self.dropout(nodes_features_proj) # in the official GAT imp they did dropout here as well
#
# Step 2: Edge attention calculation
#
# Apply the scoring function (* represents element-wise (a.k.a. Hadamard) product)
# shape = (N, NH, FOUT) * (1, NH, FOUT) -> (N, NH, 1) -> (N, NH) because sum squeezes the last dimension
# Optimization note: torch.sum() is as performant as .sum() in my experiments
scores_source = (nodes_features_proj * self.scoring_fn_source).sum(dim=-1)
scores_target = (nodes_features_proj * self.scoring_fn_target).sum(dim=-1)
# We simply copy (lift) the scores for source/target nodes based on the edge index. Instead of preparing all
# the possible combinations of scores we just prepare those that will actually be used and those are defined
# by the edge index.
# scores shape = (E, NH), nodes_features_proj_lifted shape = (E, NH, FOUT), E - number of edges in the graph
scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted = self.lift(scores_source, scores_target,
nodes_features_proj,
edge_index)
scores_per_edge = self.leakyReLU(scores_source_lifted + scores_target_lifted)
# shape = (E, NH, 1)
attentions_per_edge = self.neighborhood_aware_softmax(scores_per_edge, edge_index[self.trg_nodes_dim],
num_of_nodes)
# Add stochasticity to neighborhood aggregation
attentions_per_edge = self.dropout(attentions_per_edge)
#
# Step 3: Neighborhood aggregation
#
# Element-wise (aka Hadamard) product. Operator * does the same thing as torch.mul
# shape = (E, NH, FOUT) * (E, NH, 1) -> (E, NH, FOUT), 1 gets broadcast into FOUT
nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * attentions_per_edge
# This part sums up weighted and projected neighborhood feature vectors for every target node
# shape = (N, NH, FOUT)
out_nodes_features = self.aggregate_neighbors(nodes_features_proj_lifted_weighted, edge_index,
in_nodes_features, num_of_nodes)
#
# Step 4: Residual/skip connections, concat and bias
#
out_nodes_features = self.skip_concat_bias(attentions_per_edge, in_nodes_features, out_nodes_features)
return (out_nodes_features, edge_index)
#
# Helper functions (without comments there is very little code so don't be scared!)
#
def neighborhood_aware_softmax(self, scores_per_edge, trg_index, num_of_nodes):
"""
        As the fn name suggests, it does softmax over the neighborhoods. Example: say we have 5 nodes in a graph.
Two of them 1, 2 are connected to node 3. If we want to calculate the representation for node 3 we should take
into account feature vectors of 1, 2 and 3 itself. Since we have scores for edges 1-3, 2-3 and 3-3
in scores_per_edge variable, this function will calculate attention scores like this: 1-3/(1-3+2-3+3-3)
        (where 1-3 is overloaded notation: it represents the edge 1-3 and its (exp) score) and similarly for 2-3 and 3-3,
i.e. for this neighborhood we don't care about other edge scores that include nodes 4 and 5.
Note:
Subtracting the max value from logits doesn't change the end result but it improves the numerical stability
and it's a fairly common "trick" used in pretty much every deep learning framework.
Check out this link for more details:
https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
"""
# Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability)
scores_per_edge = scores_per_edge - scores_per_edge.max()
        exp_scores_per_edge = scores_per_edge.exp()  # numerator of the softmax
# Calculate the denominator. shape = (E, NH)
        neighborhood_aware_denominator = self.sum_edge_scores_neighborhood_aware(exp_scores_per_edge, trg_index,
                                                                                 num_of_nodes)
        # 1e-16 is theoretically not needed but is only there for numerical stability (avoid division by 0) - due to
        # the possibility of the computer rounding a very small number all the way down to 0.
        attentions_per_edge = exp_scores_per_edge / (neighborhood_aware_denominator + 1e-16)
# shape = (E, NH) -> (E, NH, 1) so that we can do element-wise multiplication with projected node features
return attentions_per_edge.unsqueeze(-1)
def sum_edge_scores_neighborhood_aware(self, exp_scores_per_edge, trg_index, num_of_nodes):
# The shape must be the same as in exp_scores_per_edge (required by scatter_add_) i.e. from E -> (E, NH)
trg_index_broadcasted = self.explicit_broadcast(trg_index, exp_scores_per_edge)
# shape = (N, NH), where N is the number of nodes and NH the number of attention heads
size = list(exp_scores_per_edge.shape) # convert to list otherwise assignment is not possible
size[self.nodes_dim] = num_of_nodes
neighborhood_sums = torch.zeros(size, dtype=exp_scores_per_edge.dtype, device=exp_scores_per_edge.device)
# position i will contain a sum of exp scores of all the nodes that point to the node i (as dictated by the
# target index)
neighborhood_sums.scatter_add_(self.nodes_dim, trg_index_broadcasted.to(device=self.device), exp_scores_per_edge.to(device=self.device))
# Expand again so that we can use it as a softmax denominator. e.g. node i's sum will be copied to
# all the locations where the source nodes pointed to i (as dictated by the target index)
# shape = (N, NH) -> (E, NH)
return neighborhood_sums.index_select(self.nodes_dim, trg_index.to(device=self.device))
def aggregate_neighbors(self, nodes_features_proj_lifted_weighted, edge_index, in_nodes_features, num_of_nodes):
size = list(nodes_features_proj_lifted_weighted.shape) # convert to list otherwise assignment is not possible
size[self.nodes_dim] = num_of_nodes # shape = (N, NH, FOUT)
out_nodes_features = torch.zeros(size, dtype=in_nodes_features.dtype, device=in_nodes_features.device)
# shape = (E) -> (E, NH, FOUT)
trg_index_broadcasted = self.explicit_broadcast(edge_index[self.trg_nodes_dim],
nodes_features_proj_lifted_weighted)
# aggregation step - we accumulate projected, weighted node features for all the attention heads
# shape = (E, NH, FOUT) -> (N, NH, FOUT)
out_nodes_features.scatter_add_(self.nodes_dim, trg_index_broadcasted.to(device=self.device), nodes_features_proj_lifted_weighted.to(device=self.device))
return out_nodes_features
def lift(self, scores_source, scores_target, nodes_features_matrix_proj, edge_index):
"""
Lifts i.e. duplicates certain vectors depending on the edge index.
One of the tensor dims goes from N -> E (that's where the "lift" comes from).
"""
src_nodes_index = edge_index[self.src_nodes_dim]
trg_nodes_index = edge_index[self.trg_nodes_dim]
# Using index_select is faster than "normal" indexing (scores_source[src_nodes_index]) in PyTorch!
scores_source = scores_source.index_select(self.nodes_dim, src_nodes_index.to(device=self.device))
scores_target = scores_target.index_select(self.nodes_dim, trg_nodes_index.to(device=self.device))
nodes_features_matrix_proj_lifted = nodes_features_matrix_proj.index_select(self.nodes_dim, src_nodes_index.to(device=self.device))
return scores_source, scores_target, nodes_features_matrix_proj_lifted
def explicit_broadcast(self, this, other):
# Append singleton dimensions until this.dim() == other.dim()
for _ in range(this.dim(), other.dim()):
this = this.unsqueeze(-1)
# Explicitly expand so that shapes are the same
return this.expand_as(other)
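

# A minimal usage sketch for GATLayerImp3; the toy edge_index and feature sizes below
# are illustrative assumptions. The layer consumes and returns a (node_features, edge_index)
# tuple, where edge_index has shape (2, E): row 0 holds source nodes, row 1 target nodes.
def _demo_gat_layer_imp3():
    device = torch.device('cpu')
    edge_index = torch.tensor([[0, 1, 2, 0, 1, 2],    # source nodes
                               [1, 2, 0, 0, 1, 2]])   # target nodes (includes self-loops)
    feats = torch.rand(3, 8)
    layer = GATLayerImp3(num_in_features=8, num_out_features=4, num_of_heads=2, device=device)
    out_feats, _ = layer((feats, edge_index))         # out_feats: (3, 2 * 4) since concat=True
    return out_feats
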
class GAT(nn.Module):
def __init__(self, in_size, hid_size, out_size, layer_num, head_num):
super().__init__()
        if layer_num == 0:
            # no GAT layers at all: fall back to a single linear projection
            self.gat_layers = nn.ModuleList()
            self.gat_linear = nn.Linear(in_size, out_size)
elif layer_num == 1:
self.gat_layers = nn.ModuleList(
[dglnn.GATConv(in_size, out_size, head_num, activation=F.elu)]
)
self.gat_linear = nn.Linear(out_size*head_num, out_size)
else:
self.gat_layers = nn.ModuleList()
self.gat_layers.append(dglnn.GATConv(in_size, hid_size, head_num, activation=F.elu))
for _ in range(layer_num - 2):
self.gat_layers.append(dglnn.GATConv(hid_size*head_num, hid_size, head_num, activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*head_num, out_size, 1, activation=F.elu))
self.gat_linear = nn.Identity()
    def forward(self, g, inputs):
        h = inputs
        if len(self.gat_layers) == 0:  # layer_num == 0: plain linear projection
            return self.gat_linear(h)
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == len(self.gat_layers) - 1:  # last layer: average over attention heads
                h = h.mean(1)
            else:  # intermediate layers: concatenate the attention heads
                h = h.flatten(1)
        return h
class GCN(nn.Module):
def __init__(self, in_size, hid_size, out_size, layer_num):
super().__init__()
self.layers = nn.ModuleList()
if layer_num == 1:
self.layers.append(
dglnn.GraphConv(in_size, out_size)
)
else:
            # multi-layer GCN: ReLU on the hidden layers, linear output layer
            self.layers.append(
                dglnn.GraphConv(in_size, hid_size, activation=F.relu)
            )
            for _ in range(layer_num - 2):
                self.layers.append(dglnn.GraphConv(hid_size, hid_size, activation=F.relu))
            self.layers.append(dglnn.GraphConv(hid_size, out_size))
self.dropout = nn.Dropout(0.5)
def forward(self, g, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(g, h)
return h
class SAGE(nn.Module):
def __init__(self, in_size, hid_size, out_size, layer_num):
super().__init__()
self.layers = nn.ModuleList()
if layer_num == 1:
self.layers.append(dglnn.SAGEConv(in_size, out_size, "gcn"))
else:
            # multi-layer GraphSAGE with the "gcn" aggregator
self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
for i in range(layer_num - 2):
self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "gcn"))
self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
self.dropout = nn.Dropout(0.5)
def forward(self, graph, x):
h = self.dropout(x)
for l, layer in enumerate(self.layers):
h = layer(graph, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
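

# A minimal smoke test for the three wrapper models, guarded so it only runs when the
# file is executed directly. The toy graph, feature size and hyperparameters below are
# illustrative assumptions, not values used elsewhere in the project.
if __name__ == "__main__":
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    g = dgl.add_self_loop(g)                   # GATConv/GraphConv expect no zero-in-degree nodes
    feats = torch.rand(g.num_nodes(), 16)
    gat = GAT(in_size=16, hid_size=8, out_size=4, layer_num=2, head_num=2)
    gcn = GCN(in_size=16, hid_size=8, out_size=4, layer_num=2)
    sage = SAGE(in_size=16, hid_size=8, out_size=4, layer_num=2)
    print('GAT  output shape:', gat(g, feats).shape)    # expected: (6, 4)
    print('GCN  output shape:', gcn(g, feats).shape)    # expected: (6, 4)
    print('SAGE output shape:', sage(g, feats).shape)   # expected: (6, 4)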