Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use NumPy randomness for random walks, not Python #1666

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions stellargraph/data/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, graph, seed=None):
raise TypeError("Graph must be a StellarGraph or StellarDiGraph.")

self.graph = graph
self._random_state, self._np_random_state = random_state(seed)
_, self._np_random_state = random_state(seed)

def _get_random_state(self, seed):
"""
Expand All @@ -72,10 +72,10 @@ def _get_random_state(self, seed):
"""
if seed is None:
# Restore the random state
return self._random_state
return self._np_random_state
# seed the random number generator
require_integer_in_range(seed, "seed", min_val=0)
rs, _ = random_state(seed)
_, rs = random_state(seed)
return rs

@staticmethod
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, graph, graph_schema=None, seed=None):

# Initialize the random state
self._check_seed(seed)
self._random_state, self._np_random_state = random_state(seed)
_, self._np_random_state = random_state(seed)

# We require a StellarGraph for this
if not isinstance(graph, StellarGraph):
Expand Down Expand Up @@ -154,9 +154,9 @@ def _get_random_state(self, seed):
"""
if seed is None:
# Use the class's random state
return self._random_state
return self._np_random_state
# seed the random number generator
rs, _ = random_state(seed)
_, rs = random_state(seed)
return rs

def neighbors(self, node):
Expand Down Expand Up @@ -669,7 +669,7 @@ def run(self, nodes, n_size, n=1, seed=None):
neighbours = [-1] * _size
else:
# sample with replacement
neighbours = rs.choices(neighbours, k=n_size[cur_depth])
neighbours = rs.choice(neighbours, size=n_size[cur_depth])

# add them to the back of the queue
q.extend((sampled_node, depth) for sampled_node in neighbours)
Expand Down Expand Up @@ -721,7 +721,7 @@ def run(self, nodes, n_size, n=1, seed=None):
q.extend([(node, node_type, 0)])

# add the root node to the walks
walk.append([node])
walk.append(np.array([node]))
while len(q) > 0:
# remove the top element in the queue and pop the item from the front of the list
frontier = q.pop(0)
Expand All @@ -743,10 +743,10 @@ def run(self, nodes, n_size, n=1, seed=None):
# In case of no neighbours of the current node for et, neigh_et == [None],
# and samples automatically becomes [None]*n_size[depth-1]
if len(neigh_et) > 0:
samples = rs.choices(neigh_et, k=n_size[depth - 1])
samples = rs.choice(neigh_et, size=n_size[depth - 1])
else: # this doesn't happen anymore, see the comment above
_size = n_size[depth - 1]
samples = [-1] * _size
samples = np.full(_size, -1)

walk.append(samples)
q.extend(
Expand Down Expand Up @@ -886,7 +886,7 @@ def _sample_neighbours(self, rs, node, idx, size):
# Sampling from empty neighbourhood
return [-1] * size
# Sample with replacement
return rs.choices(neighbours, k=size)
return rs.choice(neighbours, size=size)

def _check_neighbourhood_sizes(self, in_size, out_size):
"""
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def run(
f"max_walk_length: maximum walk length should not be less than the context window size, found {max_walk_length}"
)

np_rs = self._np_random_state if seed is None else np.random.RandomState(seed)
np_rs = self._get_random_state(seed)
walks = []
num_cw_curr = 0

Expand Down
12 changes: 6 additions & 6 deletions stellargraph/mapper/sampled_link_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,19 +439,19 @@ def sample_features(self, head_links, batch_num):
[
(
nt,
reduce(
operator.concat,
(samples[ks] for samples in node_samples for ks in indices),
[],
),
np.concatenate(
[samples[ks] for samples in node_samples for ks in indices]
)
if indices
else np.array([], dtype=np.uint8),
)
for nt, indices in self._sampling_schema[ii]
]
)

# Interlace the two lists, nodes_by_type[0] (for src head nodes) and nodes_by_type[1] (for dst head nodes)
nodes_by_type = [
tuple((ab[0][0], reduce(operator.concat, (ab[0][1], ab[1][1]))))
(ab[0][0], np.concatenate([ab[0][1], ab[1][1]]))
for ab in zip(nodes_by_type[0], nodes_by_type[1])
]

Expand Down
10 changes: 5 additions & 5 deletions stellargraph/mapper/sampled_node_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,11 @@ def sample_features(self, head_nodes, batch_num):
nodes_by_type = [
(
nt,
reduce(
operator.concat,
(samples[ks] for samples in node_samples for ks in indices),
[],
),
np.concatenate(
[samples[ks] for samples in node_samples for ks in indices]
)
if indices
else np.array([], dtype=np.uint8),
)
for nt, indices in self._sampling_schema[0]
]
Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_heterogeneous_breadth_first_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_walk_generation_single_root_node_self_loner(self):
assert len(subgraphs) == 1
assert len(subgraphs[0]) == 9
for level in subgraphs[0]:
assert type(level) == list
assert type(level) == np.ndarray
if len(level) > 0:
# All values should be rood_node_id or None
for value in level:
Expand All @@ -219,7 +219,7 @@ def test_walk_generation_single_root_node_self_loner(self):
)

for level in subgraphs2[0]:
assert type(level) == list
assert type(level) == np.ndarray
if len(level) > 0:
for value in level:
assert (value == nodes[0]) or (value == -1)
Expand Down