Skip to content

Commit

Permalink
Add support for edge weights to GraphSAGE sampling (#1667)
Browse files Browse the repository at this point in the history
This expands GraphSAGE (undirected and directed) to support weighted sampling,
where edges with higher weights are taken proportionally more often.

For example, suppose there's there's 4 edges from node A:

| source | target | weight |
|--------|--------|--------|
| A      | B      | 0      |
| A      | C      | 1      |
| A      | D      | 2      |
| A      | D      | 3      |

An unweighed walk starting at A will choose each of the edges with equal
propability and so end up on B, C or D in proportion 1:1:2 (edge counts). A
weighted walk will choose the edges proportional to the weights, so end up on
the vertices in proportion 0:1:5 (sum of edge weight). (Worth specifically
highlighting: a weighted walk will never chose the A-B edge because it has
weight 0.)

See: #462
  • Loading branch information
huonw authored Jun 12, 2020
1 parent ce8790a commit 84336e1
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 82 deletions.
2 changes: 1 addition & 1 deletion demos/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ These demos are displayed with detailed descriptions in the documentation: https
<td>GraphSAGE</td>
<td>see HinSAGE</td>
<td><a href='node-classification/directed-graphsage-node-classification.ipynb'>demo</a></td>
<td></td>
<td>yes</td>
<td></td>
<td>yes</td>
<td><a href='node-classification/graphsage-node-classification.ipynb'>demo</a></td>
Expand Down
2 changes: 1 addition & 1 deletion docs/demos/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ Find a demo for an algorithm
- GraphSAGE
- see HinSAGE
- :any:`demo <node-classification/directed-graphsage-node-classification>`
-
- yes
-
- yes
- :any:`demo <node-classification/graphsage-node-classification>`
Expand Down
1 change: 1 addition & 0 deletions scripts/demo_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def via_rl(link=None):
"GraphSAGE",
heterogeneous="see HinSAGE",
directed=T(link="node-classification/directed-graphsage-node-classification"),
weighted=True,
features=True,
nc=T(link="node-classification/graphsage-node-classification"),
lp=T(link="link-prediction/graphsage-link-prediction"),
Expand Down
143 changes: 78 additions & 65 deletions stellargraph/data/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,10 @@ def _get_random_state(self, seed):
"""
if seed is None:
# Restore the random state
return self._random_state
return self._random_state, self._np_random_state
# seed the random number generator
require_integer_in_range(seed, "seed", min_val=0)
rs, _ = random_state(seed)
return rs
return random_state(seed)

@staticmethod
def _validate_walk_params(nodes, n, length):
Expand Down Expand Up @@ -154,10 +153,9 @@ def _get_random_state(self, seed):
"""
if seed is None:
# Use the class's random state
return self._random_state
# seed the random number generator
rs, _ = random_state(seed)
return rs
return self._random_state, self._np_random_state
# seed the random number generators
return random_state(seed)

def neighbors(self, node):
return self.graph.neighbor_arrays(node, use_ilocs=True)
Expand Down Expand Up @@ -233,6 +231,37 @@ def _check_sizes(self, n_size):
if type(d) != int or d < 0:
self._raise_error(err_msg)

def _sample_neighbours_untyped(
self, neigh_func, py_and_np_rs, cur_node, size, weighted
):
"""
Sample ``size`` neighbours of ``cur_node`` without checking node types or edge types, optionally
using edge weights.
"""
if cur_node != -1:
neighbours = neigh_func(
cur_node, use_ilocs=True, include_edge_weight=weighted
)

if weighted:
neighbours, weights = neighbours
else:
neighbours = []

if len(neighbours) > 0:
if weighted:
# sample following the edge weights
idx = naive_weighted_choices(py_and_np_rs[1], weights, size=size)
if idx is not None:
return neighbours[idx]
else:
# uniform sample; for small-to-moderate `size`s (< 100 is typical for GraphSAGE), random
# has less overhead than np.random
return np.array(py_and_np_rs[0].choices(neighbours, k=size))

# no neighbours (e.g. isolated node, cur_node == -1 or all weights 0), so propagate the -1 sentinel
return np.full(size, -1)


class UniformRandomWalk(RandomWalk):
"""
Expand Down Expand Up @@ -269,7 +298,7 @@ def run(self, nodes, *, n=None, length=None, seed=None):
n = _default_if_none(n, self.n, "n")
length = _default_if_none(length, self.length, "length")
self._validate_walk_params(nodes, n, length)
rs = self._get_random_state(seed)
rs, _ = self._get_random_state(seed)

nodes = self.graph.node_ids_to_ilocs(nodes)

Expand All @@ -292,9 +321,9 @@ def _walk(self, rs, start_node, length):
return list(self.graph.node_ilocs_to_ids(walk))


def naive_weighted_choices(rs, weights):
def naive_weighted_choices(rs, weights, size=None):
"""
Select an index at random, weighted by the iterator `weights` of
Select indices at random, weighted by the iterator `weights` of
arbitrary (non-negative) floats. That is, `x` will be returned
with probability `weights[x]/sum(weights)`.
Expand All @@ -304,7 +333,13 @@ def naive_weighted_choices(rs, weights):
does a lot of conversions/checks/preprocessing internally.
"""
probs = np.cumsum(weights)
idx = np.searchsorted(probs, rs.random() * probs[-1], side="left")
total = probs[-1]
if total == 0:
# all weights were zero (probably), so we shouldn't choose anything
return None

thresholds = rs.random() if size is None else rs.random(size)
idx = np.searchsorted(probs, thresholds * total, side="left")

return idx

Expand Down Expand Up @@ -392,7 +427,7 @@ def run(
weighted = _default_if_none(weighted, self.weighted, "weighted")
self._validate_walk_params(nodes, n, length)
self._check_weights(p, q, weighted)
rs = self._get_random_state(seed)
rs, _ = self._get_random_state(seed)

nodes = self.graph.node_ids_to_ilocs(nodes)

Expand Down Expand Up @@ -445,6 +480,8 @@ def run(
weights[~mask] *= iq

choice = naive_weighted_choices(rs, weights)
if choice is None:
break

previous_node = current_node
previous_node_neighbours = neighbours
Expand Down Expand Up @@ -521,7 +558,7 @@ def run(self, nodes, *, n=None, length=None, metapaths=None, seed=None):
metapaths = _default_if_none(metapaths, self.metapaths, "metapaths")
self._validate_walk_params(nodes, n, length)
self._check_metapath_values(metapaths)
rs = self._get_random_state(seed)
rs, _ = self._get_random_state(seed)

nodes = self.graph.node_ids_to_ilocs(nodes)

Expand Down Expand Up @@ -616,7 +653,7 @@ class SampledBreadthFirstWalk(GraphWalk):
It can be used to extract a random sub-graph starting from a set of initial nodes.
"""

def run(self, nodes, n_size, n=1, seed=None):
def run(self, nodes, n_size, n=1, seed=None, weighted=False):
"""
Performs a sampled breadth-first walk starting from the root nodes.
Expand All @@ -629,13 +666,14 @@ def run(self, nodes, n_size, n=1, seed=None):
number of neighbours requested.
n (int): Number of walks per node id.
seed (int, optional): Random number generator seed; Default is None.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
Returns:
A list of lists such that each list element is a sequence of ids corresponding to a BFW.
"""
self._check_sizes(n_size)
self._check_common_parameters(nodes, n, len(n_size), seed)
rs = self._get_random_state(seed)
py_and_np_rs = self._get_random_state(seed)

walks = []
max_hops = len(n_size) # depth of search
Expand All @@ -658,18 +696,13 @@ def run(self, nodes, n_size, n=1, seed=None):
if depth > max_hops:
continue

neighbours = (
self.graph.neighbor_arrays(cur_node, use_ilocs=True)
if cur_node != -1
else []
neighbours = self._sample_neighbours_untyped(
self.graph.neighbor_arrays,
py_and_np_rs,
cur_node,
n_size[cur_depth],
weighted,
)
if len(neighbours) == 0:
# Either node is unconnected or is in directed graph with no out-nodes.
_size = n_size[cur_depth]
neighbours = [-1] * _size
else:
# sample with replacement
neighbours = rs.choices(neighbours, k=n_size[cur_depth])

# add them to the back of the queue
q.extend((sampled_node, depth) for sampled_node in neighbours)
Expand Down Expand Up @@ -704,7 +737,7 @@ def run(self, nodes, n_size, n=1, seed=None):
"""
self._check_sizes(n_size)
self._check_common_parameters(nodes, n, len(n_size), seed)
rs = self._get_random_state(seed)
rs, _ = self._get_random_state(seed)

adj = self.get_adjacency_types()

Expand Down Expand Up @@ -773,7 +806,7 @@ def __init__(self, graph, graph_schema=None, seed=None):
if not graph.is_directed():
self._raise_error("Graph must be directed")

def run(self, nodes, in_size, out_size, n=1, seed=None):
def run(self, nodes, in_size, out_size, n=1, seed=None, weighted=False):
"""
Performs a sampled breadth-first walk starting from the root nodes.
Expand All @@ -784,6 +817,7 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):
out_size (int): The number of out-directed nodes to sample with replacement at each depth of the walk.
n (int, default 1): Number of walks per node id.
seed (int, optional): Random number generator seed; default is None
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
Returns:
Expand All @@ -803,7 +837,7 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):
"""
self._check_neighbourhood_sizes(in_size, out_size)
self._check_common_parameters(nodes, n, len(in_size), seed)
rs = self._get_random_state(seed)
py_and_np_rs = self._get_random_state(seed)

max_hops = len(in_size)
# A binary tree is a graph of nodes; however, we wish to avoid overusing the term 'node'.
Expand Down Expand Up @@ -834,17 +868,25 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):
if depth > max_hops:
continue
# get in-nodes
neighbours = self._sample_neighbours(
rs, cur_node, 0, in_size[cur_depth]
neighbours = self._sample_neighbours_untyped(
self.graph.in_node_arrays,
py_and_np_rs,
cur_node,
in_size[cur_depth],
weighted,
)
# add them to the back of the queue
slot = 2 * cur_slot + 1
q.extend(
[(sampled_node, depth, slot) for sampled_node in neighbours]
)
# get out-nodes
neighbours = self._sample_neighbours(
rs, cur_node, 1, out_size[cur_depth]
neighbours = self._sample_neighbours_untyped(
self.graph.out_node_arrays,
py_and_np_rs,
cur_node,
out_size[cur_depth],
weighted,
)
# add them to the back of the queue
slot = slot + 1
Expand All @@ -857,37 +899,6 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):

return samples

def _sample_neighbours(self, rs, node, idx, size):
"""
Samples (with replacement) the specified number of nodes
from the directed neighbourhood of the given starting node.
If the neighbourhood is empty, then the result will contain
only None values.
Args:
rs: The random state used for sampling.
node: The starting node.
idx: <int> The index specifying the direction of the
neighbourhood to be sampled: 0 => in-nodes;
1 => out-nodes.
size: <int> The number of nodes to sample.
Returns:
The fixed-length list of neighbouring nodes (or None values
if the neighbourhood is empty).
"""
if node == -1:
# Non-node, e.g. previously sampled from empty neighbourhood
return [-1] * size

if idx == 0:
neighbours = self.graph.in_node_arrays(node, use_ilocs=True)
else:
neighbours = self.graph.out_node_arrays(node, use_ilocs=True)
if len(neighbours) == 0:
# Sampling from empty neighbourhood
return [-1] * size
# Sample with replacement
return rs.choices(neighbours, k=size)

def _check_neighbourhood_sizes(self, in_size, out_size):
"""
Checks that the parameter values are valid or raises ValueError exceptions with a message indicating the
Expand Down Expand Up @@ -1034,7 +1045,7 @@ def run(
f"max_walk_length: maximum walk length should not be less than the context window size, found {max_walk_length}"
)

np_rs = self._np_random_state if seed is None else np.random.RandomState(seed)
_, np_rs = self._get_random_state(seed)
walks = []
num_cw_curr = 0

Expand Down Expand Up @@ -1117,6 +1128,8 @@ def _step(self, node, time, bias_type, np_rs):
if len(neighbours) > 0:
biases = self._temporal_biases(times, time, bias_type, is_forward=True)
chosen_neighbour_index = self._sample(len(neighbours), biases, np_rs)
assert chosen_neighbour_index is not None, "biases should never be all zero"

next_node = neighbours[chosen_neighbour_index]
next_time = times[chosen_neighbour_index]
return next_node, next_time
Expand Down
27 changes: 23 additions & 4 deletions stellargraph/mapper/sampled_link_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,17 @@ class GraphSAGELinkGenerator(BatchedLinkGenerator):
batch_size (int): Size of batch of links to return.
num_samples (list): List of number of neighbour node samples per GraphSAGE layer (hop) to take.
seed (int or str), optional: Random seed for the sampling methods.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, num_samples, seed=None, name=None):
def __init__(
self, G, batch_size, num_samples, seed=None, name=None, weighted=False
):
super().__init__(G, batch_size)

self.num_samples = num_samples
self.name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.schema.node_types) > 1:
Expand Down Expand Up @@ -285,7 +289,7 @@ def get_levels(loc, lsize, samples_per_hop, walks):
batch_feats = []
for hns in zip(*head_links):
node_samples = self._samplers[batch_num].run(
nodes=hns, n=1, n_size=self.num_samples
nodes=hns, n=1, n_size=self.num_samples, weighted=self.weighted
)

nodes_per_hop = get_levels(0, 1, self.num_samples, node_samples)
Expand Down Expand Up @@ -582,14 +586,25 @@ class DirectedGraphSAGELinkGenerator(BatchedLinkGenerator):
out_samples (list): The number of out-node samples per layer (hop) to take.
seed (int or str), optional: Random seed for the sampling methods.
name, optional: Name of generator.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None):
def __init__(
self,
G,
batch_size,
in_samples,
out_samples,
seed=None,
name=None,
weighted=False,
):
super().__init__(G, batch_size)

self.in_samples = in_samples
self.out_samples = out_samples
self._name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.schema.node_types) > 1:
Expand Down Expand Up @@ -631,7 +646,11 @@ def sample_features(self, head_links, batch_num):
for hns in zip(*head_links):

node_samples = self._samplers[batch_num].run(
nodes=hns, n=1, in_size=self.in_samples, out_size=self.out_samples
nodes=hns,
n=1,
in_size=self.in_samples,
out_size=self.out_samples,
weighted=self.weighted,
)

# Reshape node samples to sensible format
Expand Down
Loading

0 comments on commit 84336e1

Please sign in to comment.