diff --git a/scripts/demo_indexing.py b/scripts/demo_indexing.py index 699af8200..70ad53f5a 100644 --- a/scripts/demo_indexing.py +++ b/scripts/demo_indexing.py @@ -450,6 +450,7 @@ def via_rl(link=None): "GraphSAGE", heterogeneous="see HinSAGE", directed=T(link="node-classification/directed-graphsage-node-classification"), + weighted=True, features=True, nc=T(link="node-classification/graphsage-node-classification"), lp=T(link="link-prediction/graphsage-link-prediction"), diff --git a/stellargraph/data/explorer.py b/stellargraph/data/explorer.py index d7d20754c..9b5d57c2b 100644 --- a/stellargraph/data/explorer.py +++ b/stellargraph/data/explorer.py @@ -233,6 +233,32 @@ def _check_sizes(self, n_size): if type(d) != int or d < 0: self._raise_error(err_msg) + def _sample_neighbours_untyped(self, neigh_func, rs, cur_node, size, weighted): + """ + Sample ``size`` neighbours of ``cur_node`` without checking node types or edge types, optionally + using edge weights. + """ + if cur_node != -1: + neighbours = neigh_func( + cur_node, use_ilocs=True, include_edge_weight=weighted + ) + + if weighted: + neighbours, weights = neighbours + else: + neighbours = [] + + if len(neighbours) == 0: + # no neighbours (e.g. isolated node or cur_node == -1), so propagate the -1 sentinel + return np.full(size, -1) + elif weighted: + # sample following the edge weights + idx = naive_weighted_choices(rs, weights, size=size) + return neighbours[idx] + else: + # uniform sample + return rs.choice(neighbours, size=size) + class UniformRandomWalk(RandomWalk): """ @@ -292,9 +318,9 @@ def _walk(self, rs, start_node, length): return list(self.graph.node_ilocs_to_ids(walk)) -def naive_weighted_choices(rs, weights): +def naive_weighted_choices(rs, weights, size=None): """ - Select an index at random, weighted by the iterator `weights` of + Select indices at random, weighted by the iterator `weights` of arbitrary (non-negative) floats. That is, `x` will be returned with probability `weights[x]/sum(weights)`. @@ -304,7 +330,7 @@ def naive_weighted_choices(rs, weights): does a lot of conversions/checks/preprocessing internally. """ probs = np.cumsum(weights) - idx = np.searchsorted(probs, rs.random() * probs[-1], side="left") + idx = np.searchsorted(probs, rs.random(size) * probs[-1], side="left") return idx @@ -616,7 +642,7 @@ class SampledBreadthFirstWalk(GraphWalk): It can be used to extract a random sub-graph starting from a set of initial nodes. """ - def run(self, nodes, n_size, n=1, seed=None): + def run(self, nodes, n_size, n=1, seed=None, weighted=False): """ Performs a sampled breadth-first walk starting from the root nodes. @@ -629,6 +655,7 @@ def run(self, nodes, n_size, n=1, seed=None): number of neighbours requested. n (int): Number of walks per node id. seed (int, optional): Random number generator seed; Default is None. + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. Returns: A list of lists such that each list element is a sequence of ids corresponding to a BFW. @@ -658,18 +685,13 @@ def run(self, nodes, n_size, n=1, seed=None): if depth > max_hops: continue - neighbours = ( - self.graph.neighbor_arrays(cur_node, use_ilocs=True) - if cur_node != -1 - else [] + neighbours = self._sample_neighbours_untyped( + self.graph.neighbor_arrays, + rs, + cur_node, + n_size[cur_depth], + weighted, ) - if len(neighbours) == 0: - # Either node is unconnected or is in directed graph with no out-nodes. - _size = n_size[cur_depth] - neighbours = [-1] * _size - else: - # sample with replacement - neighbours = rs.choice(neighbours, size=n_size[cur_depth]) # add them to the back of the queue q.extend((sampled_node, depth) for sampled_node in neighbours) @@ -773,7 +795,7 @@ def __init__(self, graph, graph_schema=None, seed=None): if not graph.is_directed(): self._raise_error("Graph must be directed") - def run(self, nodes, in_size, out_size, n=1, seed=None): + def run(self, nodes, in_size, out_size, n=1, seed=None, weighted=False): """ Performs a sampled breadth-first walk starting from the root nodes. @@ -784,6 +806,7 @@ def run(self, nodes, in_size, out_size, n=1, seed=None): out_size (int): The number of out-directed nodes to sample with replacement at each depth of the walk. n (int, default 1): Number of walks per node id. seed (int, optional): Random number generator seed; default is None + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. Returns: @@ -834,8 +857,12 @@ def run(self, nodes, in_size, out_size, n=1, seed=None): if depth > max_hops: continue # get in-nodes - neighbours = self._sample_neighbours( - rs, cur_node, 0, in_size[cur_depth] + neighbours = self._sample_neighbours_untyped( + self.graph.in_node_arrays, + rs, + cur_node, + in_size[cur_depth], + weighted, ) # add them to the back of the queue slot = 2 * cur_slot + 1 @@ -843,8 +870,12 @@ def run(self, nodes, in_size, out_size, n=1, seed=None): [(sampled_node, depth, slot) for sampled_node in neighbours] ) # get out-nodes - neighbours = self._sample_neighbours( - rs, cur_node, 1, out_size[cur_depth] + neighbours = self._sample_neighbours_untyped( + self.graph.out_node_arrays, + rs, + cur_node, + out_size[cur_depth], + weighted, ) # add them to the back of the queue slot = slot + 1 @@ -857,37 +888,6 @@ def run(self, nodes, in_size, out_size, n=1, seed=None): return samples - def _sample_neighbours(self, rs, node, idx, size): - """ - Samples (with replacement) the specified number of nodes - from the directed neighbourhood of the given starting node. - If the neighbourhood is empty, then the result will contain - only None values. - Args: - rs: The random state used for sampling. - node: The starting node. - idx: The index specifying the direction of the - neighbourhood to be sampled: 0 => in-nodes; - 1 => out-nodes. - size: The number of nodes to sample. - Returns: - The fixed-length list of neighbouring nodes (or None values - if the neighbourhood is empty). - """ - if node == -1: - # Non-node, e.g. previously sampled from empty neighbourhood - return [-1] * size - - if idx == 0: - neighbours = self.graph.in_node_arrays(node, use_ilocs=True) - else: - neighbours = self.graph.out_node_arrays(node, use_ilocs=True) - if len(neighbours) == 0: - # Sampling from empty neighbourhood - return [-1] * size - # Sample with replacement - return rs.choice(neighbours, size=size) - def _check_neighbourhood_sizes(self, in_size, out_size): """ Checks that the parameter values are valid or raises ValueError exceptions with a message indicating the diff --git a/stellargraph/mapper/sampled_link_generators.py b/stellargraph/mapper/sampled_link_generators.py index 6b60bc107..1fdd86b3d 100644 --- a/stellargraph/mapper/sampled_link_generators.py +++ b/stellargraph/mapper/sampled_link_generators.py @@ -221,13 +221,17 @@ class GraphSAGELinkGenerator(BatchedLinkGenerator): batch_size (int): Size of batch of links to return. num_samples (list): List of number of neighbour node samples per GraphSAGE layer (hop) to take. seed (int or str), optional: Random seed for the sampling methods. + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. """ - def __init__(self, G, batch_size, num_samples, seed=None, name=None): + def __init__( + self, G, batch_size, num_samples, seed=None, name=None, weighted=False + ): super().__init__(G, batch_size) self.num_samples = num_samples self.name = name + self.weighted = weighted # Check that there is only a single node type for GraphSAGE if len(self.schema.node_types) > 1: @@ -285,7 +289,7 @@ def get_levels(loc, lsize, samples_per_hop, walks): batch_feats = [] for hns in zip(*head_links): node_samples = self._samplers[batch_num].run( - nodes=hns, n=1, n_size=self.num_samples + nodes=hns, n=1, n_size=self.num_samples, weighted=self.weighted ) nodes_per_hop = get_levels(0, 1, self.num_samples, node_samples) @@ -582,14 +586,25 @@ class DirectedGraphSAGELinkGenerator(BatchedLinkGenerator): out_samples (list): The number of out-node samples per layer (hop) to take. seed (int or str), optional: Random seed for the sampling methods. name, optional: Name of generator. + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. """ - def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None): + def __init__( + self, + G, + batch_size, + in_samples, + out_samples, + seed=None, + name=None, + weighted=False, + ): super().__init__(G, batch_size) self.in_samples = in_samples self.out_samples = out_samples self._name = name + self.weighted = weighted # Check that there is only a single node type for GraphSAGE if len(self.schema.node_types) > 1: @@ -631,7 +646,11 @@ def sample_features(self, head_links, batch_num): for hns in zip(*head_links): node_samples = self._samplers[batch_num].run( - nodes=hns, n=1, in_size=self.in_samples, out_size=self.out_samples + nodes=hns, + n=1, + in_size=self.in_samples, + out_size=self.out_samples, + weighted=self.weighted, ) # Reshape node samples to sensible format diff --git a/stellargraph/mapper/sampled_node_generators.py b/stellargraph/mapper/sampled_node_generators.py index f503ef844..ecfa16e92 100644 --- a/stellargraph/mapper/sampled_node_generators.py +++ b/stellargraph/mapper/sampled_node_generators.py @@ -200,14 +200,18 @@ class GraphSAGENodeGenerator(BatchedNodeGenerator): batch_size (int): Size of batch to return. num_samples (list): The number of samples per layer (hop) to take. seed (int): [Optional] Random seed for the node sampler. + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. """ - def __init__(self, G, batch_size, num_samples, seed=None, name=None): + def __init__( + self, G, batch_size, num_samples, seed=None, name=None, weighted=False + ): super().__init__(G, batch_size) self.num_samples = num_samples self.head_node_types = self.schema.node_types self.name = name + self.weighted = weighted # Check that there is only a single node type for GraphSAGE if len(self.head_node_types) > 1: @@ -241,7 +245,7 @@ def sample_features(self, head_nodes, batch_num): for that layer. """ node_samples = self._samplers[batch_num].run( - nodes=head_nodes, n=1, n_size=self.num_samples + nodes=head_nodes, n=1, n_size=self.num_samples, weighted=self.weighted ) # The number of samples for each head node (not including itself) @@ -304,9 +308,19 @@ class DirectedGraphSAGENodeGenerator(BatchedNodeGenerator): in_samples (list): The number of in-node samples per layer (hop) to take. out_samples (list): The number of out-node samples per layer (hop) to take. seed (int): [Optional] Random seed for the node sampler. + weighted (bool, optional): If True, sample neighbours using the edge weights in the graph. """ - def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None): + def __init__( + self, + G, + batch_size, + in_samples, + out_samples, + seed=None, + name=None, + weighted=False, + ): super().__init__(G, batch_size) # TODO Add checks for in- and out-nodes sizes @@ -314,6 +328,7 @@ def __init__(self, G, batch_size, in_samples, out_samples, seed=None, name=None) self.out_samples = out_samples self.head_node_types = self.schema.node_types self.name = name + self.weighted = weighted # Check that there is only a single node type for GraphSAGE if len(self.head_node_types) > 1: @@ -350,7 +365,11 @@ def sample_features(self, head_nodes, batch_num): given the sequence of in/out directions. """ node_samples = self.sampler.run( - nodes=head_nodes, n=1, in_size=self.in_samples, out_size=self.out_samples + nodes=head_nodes, + n=1, + in_size=self.in_samples, + out_size=self.out_samples, + weighted=self.weighted, ) # Reshape node samples to sensible format diff --git a/tests/data/test_breadth_first_walker.py b/tests/data/test_breadth_first_walker.py index d306e8d72..d4a31b5fb 100644 --- a/tests/data/test_breadth_first_walker.py +++ b/tests/data/test_breadth_first_walker.py @@ -18,8 +18,13 @@ import pytest import numpy as np from stellargraph.data.explorer import SampledBreadthFirstWalk -from stellargraph.core.graph import StellarDiGraph -from ..test_utils.graphs import create_test_graph, tree_graph, example_graph_random +from stellargraph.core.graph import StellarGraph, StellarDiGraph +from ..test_utils.graphs import ( + create_test_graph, + tree_graph, + example_graph_random, + weighted_tree, +) def expected_bfw_size(n_size): @@ -539,7 +544,8 @@ def test_fixed_random_seed(self): assert len(w0) == len(w1) assert w0 == w1 - def test_benchmark_bfs_walk(self, benchmark): + @pytest.mark.parametrize("weighted", [False, True]) + def test_benchmark_bfs_walk(self, benchmark, weighted): g = example_graph_random(n_nodes=100, n_edges=500) bfw = SampledBreadthFirstWalk(g) @@ -547,4 +553,11 @@ def test_benchmark_bfs_walk(self, benchmark): n = 5 n_size = [5, 5] - benchmark(lambda: bfw.run(nodes=nodes, n=n, n_size=n_size)) + benchmark(lambda: bfw.run(nodes=nodes, n=n, n_size=n_size, weighted=weighted)) + + def test_weighted(self): + g, checker = weighted_tree() + bfw = SampledBreadthFirstWalk(g) + walks = bfw.run(nodes=[0], n=10, n_size=[20, 20], weighted=True) + + checker(node_id for walk in walks for node_id in walk) diff --git a/tests/data/test_directed_breadth_first_sampler.py b/tests/data/test_directed_breadth_first_sampler.py index 9e6342674..9d3be83a5 100644 --- a/tests/data/test_directed_breadth_first_sampler.py +++ b/tests/data/test_directed_breadth_first_sampler.py @@ -19,7 +19,12 @@ import numpy as np from stellargraph.data.explorer import DirectedBreadthFirstNeighbours from stellargraph.core.graph import StellarDiGraph -from ..test_utils.graphs import create_test_graph, tree_graph, example_graph_random +from ..test_utils.graphs import ( + create_test_graph, + tree_graph, + example_graph_random, + weighted_tree, +) class TestDirectedBreadthFirstNeighbours(object): @@ -229,7 +234,8 @@ def test_three_hops(self): assert len(subgraph[0][13]) == out_size[0] * out_size[1] * in_size[2] assert len(subgraph[0][14]) == out_size[0] * out_size[1] * out_size[2] - def test_benchmark_bfs_walk(self, benchmark): + @pytest.mark.parametrize("weighted", [False, True]) + def test_benchmark_bfs_walk(self, benchmark, weighted): g = example_graph_random(n_nodes=100, n_edges=500, is_directed=True) bfw = DirectedBreadthFirstNeighbours(g) @@ -238,4 +244,18 @@ def test_benchmark_bfs_walk(self, benchmark): in_size = [5, 5] out_size = [5, 5] - benchmark(lambda: bfw.run(nodes=nodes, n=n, in_size=in_size, out_size=out_size)) + benchmark( + lambda: bfw.run( + nodes=nodes, n=n, in_size=in_size, out_size=out_size, weighted=weighted + ) + ) + + def test_weighted(self): + g, checker = weighted_tree(is_directed=True) + + bfw = DirectedBreadthFirstNeighbours(g) + walks = bfw.run( + nodes=[0], n=10, in_size=[20, 20], out_size=[20, 20], weighted=True + ) + + checker(node_id for walk in walks for hop in walk for node_id in hop) diff --git a/tests/mapper/test_directed_node_generator.py b/tests/mapper/test_directed_node_generator.py index ea2bdc108..48c0f61cd 100644 --- a/tests/mapper/test_directed_node_generator.py +++ b/tests/mapper/test_directed_node_generator.py @@ -19,6 +19,8 @@ from stellargraph.mapper import DirectedGraphSAGENodeGenerator from stellargraph.core.graph import StellarDiGraph +from ..test_utils.graphs import weighted_tree + # FIXME (#535): Consider using graph fixtures def create_simple_graph(): @@ -228,3 +230,11 @@ def test_two_hop(self): assert out_features[idx, 0, 0] == 0.0 else: assert False + + def test_weighted(self): + g, checker = weighted_tree(is_directed=True) + + gen = DirectedGraphSAGENodeGenerator(g, 7, [5, 3], [5, 3], weighted=True) + samples = gen.flow([0] * 10) + + checker(node_id for array in samples[0][0] for node_id in array.ravel()) diff --git a/tests/mapper/test_link_mappers.py b/tests/mapper/test_link_mappers.py index e50ba6880..20cf5fdc0 100644 --- a/tests/mapper/test_link_mappers.py +++ b/tests/mapper/test_link_mappers.py @@ -41,6 +41,7 @@ example_graph_random, example_hin_1, repeated_features, + weighted_tree, ) from .. import test_utils @@ -409,6 +410,13 @@ def test_GraphSAGELinkGenerator_unsupervisedSampler_sample_generation(self): with pytest.raises(IndexError): nf, nl = mapper[8] + def test_weighted(self): + g, checker = weighted_tree() + + gen = GraphSAGELinkGenerator(g, 7, [10, 5], weighted=True) + samples = gen.flow([(0, 0)] * 10) + checker(node_id for array in samples[0][0] for node_id in array.ravel()) + class Test_HinSAGELinkGenerator(object): """ @@ -1231,3 +1239,10 @@ def test_unsupervisedSampler_sample_generation(self): == nf[2 * ii + 1].shape == (min(self.batch_size, mapper.data_size), dim, self.n_feat) ) + + def test_weighted(self): + g, checker = weighted_tree(is_directed=True) + + gen = DirectedGraphSAGELinkGenerator(g, 7, [5, 3], [5, 3], weighted=True) + samples = gen.flow([(0, 0)] * 10) + checker(node_id for array in samples[0][0] for node_id in array.ravel()) diff --git a/tests/mapper/test_node_mappers.py b/tests/mapper/test_node_mappers.py index af37ef3e6..c880c24ea 100644 --- a/tests/mapper/test_node_mappers.py +++ b/tests/mapper/test_node_mappers.py @@ -34,6 +34,7 @@ example_hin_1, create_graph_features, repeated_features, + weighted_tree, ) from .. import test_utils @@ -353,6 +354,15 @@ def test_nodemapper_incorrect_targets(): ) +def test_nodemapper_weighted(): + g, checker = weighted_tree() + + gen = GraphSAGENodeGenerator(g, 7, [10, 6], weighted=True) + samples = gen.flow([0] * 10) + + checker(node_id for array in samples[0][0] for node_id in array.ravel()) + + def test_hinnodemapper_constructor(): feature_sizes = {"A": 10, "B": 10} G = example_hin_1(feature_sizes=feature_sizes) diff --git a/tests/test_utils/graphs.py b/tests/test_utils/graphs.py index 83cdd537e..512c50450 100644 --- a/tests/test_utils/graphs.py +++ b/tests/test_utils/graphs.py @@ -375,3 +375,35 @@ def weighted_hin(): ) return StellarGraph(nodes={"A": a, "B": b}, edges={"R": r, "S": s, "T": t, "U": u}) + + +def weighted_tree(is_directed=False): + # a binary tree: + # 0---1--3 + # | | + # | 4 + # 2--5 + # | + # 6 + nodes = repeated_features(range(7), 3) + edges = pd.DataFrame( + { + "source": [0, 0, 1, 1, 2, 2], + "target": [1, 2, 3, 4, 5, 6], + "weight": [1.0, 2, 10, 1, 1, 0], + } + ) + cls = StellarDiGraph if is_directed else StellarGraph + + def check_occurrence(sequence): + from collections import Counter + + occurrence = Counter(sequence) + # 0--2 has higher weight than 0--1 + assert occurrence[2] > occurrence[1] + # 1--3 has higher weight than 1--4 + assert occurrence[3] > occurrence[4] + # 2--6 has 0 weight (i.e. should never be taken) + assert 6 not in occurrence + + return cls(nodes, edges), check_occurrence