-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathutils.py
131 lines (106 loc) · 4.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
import torch
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter
from yacs.config import CfgNode
def negate_edge_index(edge_index, batch=None):
    """Negate batched sparse adjacency matrices given by edge indices.

    Returns batched sparse adjacency matrices with exactly those edges that
    are not in the input `edge_index` while ignoring self-loops.

    Implementation inspired by `torch_geometric.utils.to_dense_adj`

    Args:
        edge_index: The edge indices, indexable as `edge_index[0]` (source
            nodes) and `edge_index[1]` (target nodes).
        batch: Batch vector, which assigns each node to a specific example.
            If None, all nodes are placed in a single example and the number
            of nodes is inferred as `edge_index.max() + 1`.

    Returns:
        Complementary edge index of shape (2, num_negative_edges). Within
        each example the edges are listed in row-major order of the dense
        adjacency matrix, and examples are concatenated in batch order.
    """
    if batch is None:
        if edge_index.numel() == 0:
            # No edges and no batch vector: zero nodes can be inferred, so
            # the complement is empty (calling `.max()` here would raise).
            return edge_index.new_empty((2, 0))
        batch = edge_index.new_zeros(edge_index.max().item() + 1)

    batch_size = batch.max().item() + 1

    # Number of nodes in each example, and the index of each example's
    # first node in the batched node numbering.
    num_nodes = batch.new_zeros(batch_size).scatter_add_(
        0, batch, torch.ones_like(batch))
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Example id of every edge (taken from its source node) and the
    # example-local endpoints of every edge.
    # NOTE(review): assumes both endpoints of an edge belong to the same
    # example - confirm against callers.
    src, dst = edge_index[0], edge_index[1]
    edge_graph = batch[src]
    node_offset = cum_nodes[batch]
    local_src = src - node_offset[src]
    local_dst = dst - node_offset[dst]

    negative_index_list = []
    for i in range(batch_size):
        n = num_nodes[i].item()

        # Start from "every ordered pair except self-loops" ...
        adj = ~torch.eye(n, dtype=torch.bool, device=edge_index.device)

        # ... and knock out the edges that exist in this example.
        in_graph = edge_graph == i
        adj[local_src[in_graph], local_dst[in_graph]] = False

        # Convert to edge index format and shift back to batched node ids.
        _edge_index = adj.nonzero(as_tuple=False).t().contiguous()
        negative_index_list.append(_edge_index + cum_nodes[i])

    return torch.cat(negative_index_list, dim=1).contiguous()
def flatten_dict(metrics):
    """Flatten a list of train/val/test metrics into one dict to send to wandb.

    Only the most recent entry of each split's history is used.

    Args:
        metrics: List of per-split histories, in the order train, val, test.
            Each history is a sequence of dicts with metrics.

    Returns:
        A flat dictionary with names prefixed with "train/" , "val/" , "test/"
    """
    split_names = ['train', 'val', 'test']
    flat = {}
    for split_idx, history in enumerate(metrics):
        # Take the latest metrics.
        latest = history[-1]
        for key, value in latest.items():
            flat[f"{split_names[split_idx]}/{key}"] = value
    return flat
def cfg_to_dict(cfg_node, key_list=None):
    """Convert a config node to dictionary.

    Yacs doesn't have a default function to convert the cfg object to plain
    python dict. The following function was taken from
    https://github.com/rbgirshick/yacs/issues/19

    Args:
        cfg_node: A `CfgNode` to convert, or (in recursive calls) a leaf
            value, which is returned unchanged.
        key_list: Keys leading from the root config to `cfg_node`; only used
            to name the offending key in the warning message. Defaults to
            None, meaning the root (empty path).

    Returns:
        A plain dict mirroring `cfg_node` with nested `CfgNode`s converted
        recursively, or `cfg_node` itself when it is not a `CfgNode`.
    """
    _VALID_TYPES = {tuple, list, str, int, float, bool}

    # Avoid a mutable default argument; None stands for the root path.
    if key_list is None:
        key_list = []

    if not isinstance(cfg_node, CfgNode):
        # Exact type match on purpose: subclasses of the valid types are
        # reported as well.
        if type(cfg_node) not in _VALID_TYPES:
            logging.warning(f"Key {'.'.join(key_list)} with "
                            f"value {type(cfg_node)} is not "
                            f"a valid type; valid types: {_VALID_TYPES}")
        return cfg_node

    return {k: cfg_to_dict(v, [*key_list, k]) for k, v in cfg_node.items()}
def make_wandb_name(cfg):
    """Compose a wandb run name of the form "<dataset>.<model>.r<run_id>".

    Args:
        cfg: Config object; reads `dataset.format`, `dataset.name`,
            `model.type`, `gnn.layer_type` or `gt.layer_type` (depending on
            the model type), `name_tag` and `run_id`.

    Returns:
        The run name string.
    """
    # Dataset part: start from the dataset format with known prefixes removed.
    dataset_part = cfg.dataset.format
    if dataset_part.startswith('OGB'):
        dataset_part = dataset_part[3:]
    if dataset_part.startswith('PyG-'):
        dataset_part = dataset_part[4:]
    if dataset_part in ['GNNBenchmarkDataset', 'TUDataset']:
        # Shorten some verbose dataset naming schemes.
        dataset_part = ""

    if cfg.dataset.name != 'none':
        if cfg.dataset.name == 'LocalDegreeProfile':
            short_name = 'LDP'
        else:
            short_name = cfg.dataset.name
        separator = "-" if dataset_part != "" else ""
        dataset_part = dataset_part + separator + short_name

    # Model part: model type, refined by the layer type where applicable.
    model_type = cfg.model.type
    if model_type in ['gnn', 'custom_gnn']:
        model_part = model_type + f".{cfg.gnn.layer_type}"
    elif model_type == 'GPSModel':
        model_part = f"GPS.{cfg.gt.layer_type}"
    else:
        model_part = model_type
    if cfg.name_tag:
        model_part += f".{cfg.name_tag}"

    # Compose wandb run name.
    return f"{dataset_part}.{model_part}.r{cfg.run_id}"