Skip to content

Commit

Permalink
Add support for edge weights to GraphSAGE sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
huonw committed Jun 10, 2020
1 parent 3bf3330 commit c362597
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 66 deletions.
1 change: 1 addition & 0 deletions scripts/demo_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def via_rl(link=None):
"GraphSAGE",
heterogeneous="see HinSAGE",
directed=T(link="node-classification/directed-graphsage-node-classification"),
weighted=True,
features=True,
nc=T(link="node-classification/graphsage-node-classification"),
lp=T(link="link-prediction/graphsage-link-prediction"),
Expand Down
102 changes: 51 additions & 51 deletions stellargraph/data/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,32 @@ def _check_sizes(self, n_size):
if type(d) != int or d < 0:
self._raise_error(err_msg)

def _sample_neighbours_untyped(self, neigh_func, rs, cur_node, size, weighted):
"""
Sample ``size`` neighbours of ``cur_node`` without checking node types or edge types, optionally
using edge weights.
"""
if cur_node != -1:
neighbours = neigh_func(
cur_node, use_ilocs=True, include_edge_weight=weighted
)

if weighted:
neighbours, weights = neighbours
else:
neighbours = []

if len(neighbours) == 0:
# no neighbours (e.g. isolated node or cur_node == -1), so propagate the -1 sentinel
return np.full(size, -1)
elif weighted:
# sample following the edge weights
idx = naive_weighted_choices(rs, weights, size=size)
return neighbours[idx]
else:
# uniform sample
return rs.choice(neighbours, size=size)


class UniformRandomWalk(RandomWalk):
"""
Expand Down Expand Up @@ -292,9 +318,9 @@ def _walk(self, rs, start_node, length):
return list(self.graph.node_ilocs_to_ids(walk))


def naive_weighted_choices(rs, weights):
def naive_weighted_choices(rs, weights, size=None):
"""
Select an index at random, weighted by the iterator `weights` of
Select indices at random, weighted by the iterator `weights` of
arbitrary (non-negative) floats. That is, `x` will be returned
with probability `weights[x]/sum(weights)`.
Expand All @@ -304,7 +330,7 @@ def naive_weighted_choices(rs, weights):
does a lot of conversions/checks/preprocessing internally.
"""
probs = np.cumsum(weights)
idx = np.searchsorted(probs, rs.random() * probs[-1], side="left")
idx = np.searchsorted(probs, rs.random(size) * probs[-1], side="left")

return idx

Expand Down Expand Up @@ -616,7 +642,7 @@ class SampledBreadthFirstWalk(GraphWalk):
It can be used to extract a random sub-graph starting from a set of initial nodes.
"""

def run(self, nodes, n_size, n=1, seed=None):
def run(self, nodes, n_size, n=1, seed=None, weighted=False):
"""
Performs a sampled breadth-first walk starting from the root nodes.
Expand All @@ -629,6 +655,7 @@ def run(self, nodes, n_size, n=1, seed=None):
number of neighbours requested.
n (int): Number of walks per node id.
seed (int, optional): Random number generator seed; Default is None.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
Returns:
A list of lists such that each list element is a sequence of ids corresponding to a BFW.
Expand Down Expand Up @@ -658,18 +685,13 @@ def run(self, nodes, n_size, n=1, seed=None):
if depth > max_hops:
continue

neighbours = (
self.graph.neighbor_arrays(cur_node, use_ilocs=True)
if cur_node != -1
else []
neighbours = self._sample_neighbours_untyped(
self.graph.neighbor_arrays,
rs,
cur_node,
n_size[cur_depth],
weighted,
)
if len(neighbours) == 0:
# Either node is unconnected or is in directed graph with no out-nodes.
_size = n_size[cur_depth]
neighbours = [-1] * _size
else:
# sample with replacement
neighbours = rs.choice(neighbours, size=n_size[cur_depth])

# add them to the back of the queue
q.extend((sampled_node, depth) for sampled_node in neighbours)
Expand Down Expand Up @@ -773,7 +795,7 @@ def __init__(self, graph, graph_schema=None, seed=None):
if not graph.is_directed():
self._raise_error("Graph must be directed")

def run(self, nodes, in_size, out_size, n=1, seed=None):
def run(self, nodes, in_size, out_size, n=1, seed=None, weighted=False):
"""
Performs a sampled breadth-first walk starting from the root nodes.
Expand All @@ -784,6 +806,7 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):
out_size (int): The number of out-directed nodes to sample with replacement at each depth of the walk.
n (int, default 1): Number of walks per node id.
seed (int, optional): Random number generator seed; default is None
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
Returns:
Expand Down Expand Up @@ -834,17 +857,25 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):
if depth > max_hops:
continue
# get in-nodes
neighbours = self._sample_neighbours(
rs, cur_node, 0, in_size[cur_depth]
neighbours = self._sample_neighbours_untyped(
self.graph.in_node_arrays,
rs,
cur_node,
in_size[cur_depth],
weighted,
)
# add them to the back of the queue
slot = 2 * cur_slot + 1
q.extend(
[(sampled_node, depth, slot) for sampled_node in neighbours]
)
# get out-nodes
neighbours = self._sample_neighbours(
rs, cur_node, 1, out_size[cur_depth]
neighbours = self._sample_neighbours_untyped(
self.graph.out_node_arrays,
rs,
cur_node,
out_size[cur_depth],
weighted,
)
# add them to the back of the queue
slot = slot + 1
Expand All @@ -857,37 +888,6 @@ def run(self, nodes, in_size, out_size, n=1, seed=None):

return samples

def _sample_neighbours(self, rs, node, idx, size):
"""
Samples (with replacement) the specified number of nodes
from the directed neighbourhood of the given starting node.
If the neighbourhood is empty, then the result will contain
only None values.
Args:
rs: The random state used for sampling.
node: The starting node.
idx: <int> The index specifying the direction of the
neighbourhood to be sampled: 0 => in-nodes;
1 => out-nodes.
size: <int> The number of nodes to sample.
Returns:
The fixed-length list of neighbouring nodes (or None values
if the neighbourhood is empty).
"""
if node == -1:
# Non-node, e.g. previously sampled from empty neighbourhood
return [-1] * size

if idx == 0:
neighbours = self.graph.in_node_arrays(node, use_ilocs=True)
else:
neighbours = self.graph.out_node_arrays(node, use_ilocs=True)
if len(neighbours) == 0:
# Sampling from empty neighbourhood
return [-1] * size
# Sample with replacement
return rs.choice(neighbours, size=size)

def _check_neighbourhood_sizes(self, in_size, out_size):
"""
Checks that the parameter values are valid or raises ValueError exceptions with a message indicating the
Expand Down
27 changes: 23 additions & 4 deletions stellargraph/mapper/sampled_link_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,17 @@ class GraphSAGELinkGenerator(BatchedLinkGenerator):
batch_size (int): Size of batch of links to return.
num_samples (list): List of number of neighbour node samples per GraphSAGE layer (hop) to take.
seed (int or str), optional: Random seed for the sampling methods.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, num_samples, seed=None, name=None):
def __init__(
self, G, batch_size, num_samples, seed=None, name=None, weighted=False
):
super().__init__(G, batch_size)

self.num_samples = num_samples
self.name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.schema.node_types) > 1:
Expand Down Expand Up @@ -285,7 +289,7 @@ def get_levels(loc, lsize, samples_per_hop, walks):
batch_feats = []
for hns in zip(*head_links):
node_samples = self._samplers[batch_num].run(
nodes=hns, n=1, n_size=self.num_samples
nodes=hns, n=1, n_size=self.num_samples, weighted=self.weighted
)

nodes_per_hop = get_levels(0, 1, self.num_samples, node_samples)
Expand Down Expand Up @@ -582,14 +586,25 @@ class DirectedGraphSAGELinkGenerator(BatchedLinkGenerator):
out_samples (list): The number of out-node samples per layer (hop) to take.
seed (int or str), optional: Random seed for the sampling methods.
name, optional: Name of generator.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None):
def __init__(
self,
G,
batch_size,
in_samples,
out_samples,
seed=None,
name=None,
weighted=False,
):
super().__init__(G, batch_size)

self.in_samples = in_samples
self.out_samples = out_samples
self._name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.schema.node_types) > 1:
Expand Down Expand Up @@ -631,7 +646,11 @@ def sample_features(self, head_links, batch_num):
for hns in zip(*head_links):

node_samples = self._samplers[batch_num].run(
nodes=hns, n=1, in_size=self.in_samples, out_size=self.out_samples
nodes=hns,
n=1,
in_size=self.in_samples,
out_size=self.out_samples,
weighted=self.weighted,
)

# Reshape node samples to sensible format
Expand Down
27 changes: 23 additions & 4 deletions stellargraph/mapper/sampled_node_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,18 @@ class GraphSAGENodeGenerator(BatchedNodeGenerator):
batch_size (int): Size of batch to return.
num_samples (list): The number of samples per layer (hop) to take.
seed (int): [Optional] Random seed for the node sampler.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, num_samples, seed=None, name=None):
def __init__(
self, G, batch_size, num_samples, seed=None, name=None, weighted=False
):
super().__init__(G, batch_size)

self.num_samples = num_samples
self.head_node_types = self.schema.node_types
self.name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.head_node_types) > 1:
Expand Down Expand Up @@ -241,7 +245,7 @@ def sample_features(self, head_nodes, batch_num):
for that layer.
"""
node_samples = self._samplers[batch_num].run(
nodes=head_nodes, n=1, n_size=self.num_samples
nodes=head_nodes, n=1, n_size=self.num_samples, weighted=self.weighted
)

# The number of samples for each head node (not including itself)
Expand Down Expand Up @@ -304,16 +308,27 @@ class DirectedGraphSAGENodeGenerator(BatchedNodeGenerator):
in_samples (list): The number of in-node samples per layer (hop) to take.
out_samples (list): The number of out-node samples per layer (hop) to take.
seed (int): [Optional] Random seed for the node sampler.
weighted (bool, optional): If True, sample neighbours using the edge weights in the graph.
"""

def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None):
def __init__(
self,
G,
batch_size,
in_samples,
out_samples,
seed=None,
name=None,
weighted=False,
):
super().__init__(G, batch_size)

# TODO Add checks for in- and out-nodes sizes
self.in_samples = in_samples
self.out_samples = out_samples
self.head_node_types = self.schema.node_types
self.name = name
self.weighted = weighted

# Check that there is only a single node type for GraphSAGE
if len(self.head_node_types) > 1:
Expand Down Expand Up @@ -350,7 +365,11 @@ def sample_features(self, head_nodes, batch_num):
given the sequence of in/out directions.
"""
node_samples = self.sampler.run(
nodes=head_nodes, n=1, in_size=self.in_samples, out_size=self.out_samples
nodes=head_nodes,
n=1,
in_size=self.in_samples,
out_size=self.out_samples,
weighted=self.weighted,
)

# Reshape node samples to sensible format
Expand Down
21 changes: 17 additions & 4 deletions tests/data/test_breadth_first_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
import pytest
import numpy as np
from stellargraph.data.explorer import SampledBreadthFirstWalk
from stellargraph.core.graph import StellarDiGraph
from ..test_utils.graphs import create_test_graph, tree_graph, example_graph_random
from stellargraph.core.graph import StellarGraph, StellarDiGraph
from ..test_utils.graphs import (
create_test_graph,
tree_graph,
example_graph_random,
weighted_tree,
)


def expected_bfw_size(n_size):
Expand Down Expand Up @@ -539,12 +544,20 @@ def test_fixed_random_seed(self):
assert len(w0) == len(w1)
assert w0 == w1

def test_benchmark_bfs_walk(self, benchmark):
@pytest.mark.parametrize("weighted", [False, True])
def test_benchmark_bfs_walk(self, benchmark, weighted):
g = example_graph_random(n_nodes=100, n_edges=500)
bfw = SampledBreadthFirstWalk(g)

nodes = np.arange(0, 50)
n = 5
n_size = [5, 5]

benchmark(lambda: bfw.run(nodes=nodes, n=n, n_size=n_size))
benchmark(lambda: bfw.run(nodes=nodes, n=n, n_size=n_size, weighted=weighted))

def test_weighted(self):
g, checker = weighted_tree()
bfw = SampledBreadthFirstWalk(g)
walks = bfw.run(nodes=[0], n=10, n_size=[20, 20], weighted=True)

checker(node_id for walk in walks for node_id in walk)
Loading

0 comments on commit c362597

Please sign in to comment.