Skip to content

Commit

Permalink
Merge branch 'master' into use_k8s
Browse files (browse the repository at this point in the history)
  • Loading branch information
VoVAllen authored May 31, 2021
2 parents e5eda57 + 2ad7a9e commit 19d3bce
Showing 2 changed files with 53 additions and 18 deletions.
4 changes: 2 additions & 2 deletions python/dgl/batch.py
Original file line number Diff line number Diff line change
@@ -168,9 +168,9 @@ def batch(graphs, ndata=ALL, edata=ALL, *,
if any(g.is_block for g in graphs):
raise DGLError("Batching a MFG is not supported.")

relations = list(sorted(graphs[0].canonical_etypes))
relations = list(graphs[0].canonical_etypes)
relation_ids = [graphs[0].get_etype_id(r) for r in relations]
ntypes = list(sorted(graphs[0].ntypes))
ntypes = list(graphs[0].ntypes)
ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
etypes = [etype for _, etype, _ in relations]

67 changes: 51 additions & 16 deletions tests/compute/test_batched_heterograph.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
from utils import parametrize_dtype
from test_utils import check_graph_equal, get_cases


def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes
assert g1.etypes == g2.etypes
@@ -32,14 +33,17 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
if g1.number_of_nodes(nty) == 0:
continue
for feat_name in node_attrs[nty]:
assert F.allclose(g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name])
assert F.allclose(
g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name])

if edge_attrs is not None:
for ety in edge_attrs.keys():
if g1.number_of_edges(ety) == 0:
continue
for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
assert F.allclose(
g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])


@pytest.mark.parametrize('gs', get_cases(['two_hetero_batch']))
@parametrize_dtype
@@ -63,7 +67,7 @@ def test_topology(gs, idtype):
assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
assert bg.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))

# Test number of edges
for etype in bg.canonical_etypes:
@@ -74,7 +78,8 @@ def test_topology(gs, idtype):

# Test relabeled nodes
for ntype in bg.ntypes:
assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))
assert list(F.asnumpy(bg.nodes(ntype))) == list(
range(bg.number_of_nodes(ntype)))

# Test relabeled edges
src, dst = bg.edges(etype=('user', 'follows', 'user'))
@@ -104,6 +109,7 @@ def test_topology(gs, idtype):
bg_local = bg.local_var()
assert bg.batch_size == bg_local.batch_size


@parametrize_dtype
def test_batching_batched(idtype):
"""Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
@@ -133,18 +139,19 @@ def test_batching_batched(idtype):
assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype), g3.number_of_nodes(ntype)]
assert bg2.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype) + g3.number_of_nodes(ntype))
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype) + g3.number_of_nodes(ntype))

# Test number of edges
for etype in bg2.canonical_etypes:
assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [
g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype)]
assert bg2.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))
g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))

# Test relabeled nodes
for ntype in bg2.ntypes:
assert list(F.asnumpy(bg2.nodes(ntype))) == list(range(bg2.number_of_nodes(ntype)))
assert list(F.asnumpy(bg2.nodes(ntype))) == list(
range(bg2.number_of_nodes(ntype)))

# Test relabeled edges
src, dst = bg2.edges(etype='follows')
@@ -160,6 +167,7 @@ def test_batching_batched(idtype):
check_equivalence_between_heterographs(g2, g5)
check_equivalence_between_heterographs(g3, g6)


@parametrize_dtype
def test_features(idtype):
"""Test the features of batched DGLHeteroGraphs"""
@@ -233,6 +241,7 @@ def test_features(idtype):
bg = dgl.batch([g1, g2], edge_attrs=['h1'])
assert 'h2' not in bg.edges['follows'].data.keys()


@unittest.skipIf(F.backend_name == 'mxnet', reason="MXNet does not support split array with zero-length segment.")
@parametrize_dtype
def test_empty_relation(idtype):
@@ -279,7 +288,8 @@ def test_empty_relation(idtype):
assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2'])
assert F.allclose(bg.edges['follows'].data['h1'],
F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
assert F.allclose(bg.edges['plays'].data['h1'], g2.edges['plays'].data['h1'])
assert F.allclose(bg.edges['plays'].data['h1'],
g2.edges['plays'].data['h1'])

# Test unbatching graphs
g3, g4 = dgl.unbatch(bg)
@@ -297,6 +307,7 @@ def test_empty_relation(idtype):
g2 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 1, 'v': 5})
dgl.batch([g1, g2])


@parametrize_dtype
def test_unbatch2(idtype):
# batch 3 graphs but unbatch to 2
@@ -321,22 +332,46 @@ def test_unbatch2(idtype):
check_graph_equal(g2, gg2)
check_graph_equal(g3, gg3)


@parametrize_dtype
def test_batch_keeps_empty_data(idtype):
    """Check that batching keeps feature keys whose tensors are empty.

    Two heterographs are built, each with a single ("a", "to", "a")
    relation that has no edges.  Each one gets an empty node feature
    "nh" and an empty edge feature "eh".  After ``dgl.batch`` both keys
    must still be present on the batched graph.

    Parameters
    ----------
    idtype
        Index dtype passed to ``astype``; supplied by ``parametrize_dtype``.
    """
    etype = ("a", "to", "a")

    g1 = dgl.heterograph({etype: ([], [])}).astype(idtype).to(F.ctx())
    g1.nodes["a"].data["nh"] = F.tensor([])
    g1.edges[etype].data["eh"] = F.tensor([])

    g2 = dgl.heterograph({etype: ([], [])}).astype(idtype).to(F.ctx())
    g2.nodes["a"].data["nh"] = F.tensor([])
    g2.edges[etype].data["eh"] = F.tensor([])

    g = dgl.batch([g1, g2])

    # The feature names must be kept even though every tensor is empty.
    assert "nh" in g.nodes["a"].data
    assert "eh" in g.edges[etype].data


@unittest.skipIf(F._default_context_str == 'gpu', reason="Issue is not related with GPU")
def test_batch_netypes():
    """Regression test for https://github.com/dmlc/dgl/issues/2808.

    A bipartite NetworkX digraph is converted with
    ``dgl.bipartite_from_networkx`` four ways: node type names given as
    ('A', 'B') and as ('B', 'A'), each with and without
    ``u_attrs=['some_attr']``.  Three copies of each resulting graph are
    then batched.  There are no assertions; the test passes when none of
    the ``dgl.batch`` calls raises.
    """
    import networkx as nx

    nx_graph = nx.DiGraph()
    nx_graph.add_nodes_from(
        [1, 2, 3, 4],
        bipartite=0,
        some_attr=F.tensor([1, 2, 3, 4], dtype=F.float32))
    nx_graph.add_nodes_from(["a", "b", "c"], bipartite=1)
    nx_graph.add_edges_from([
        (1, "a"), (1, "b"), (2, "b"),
        (2, "c"), (3, "c"), (4, "a"),
    ])

    # Same construction order as before; the integer keys were never used.
    converted = [
        dgl.bipartite_from_networkx(nx_graph, 'A', 'e', 'B'),
        dgl.bipartite_from_networkx(nx_graph, 'B', 'e', 'A'),
        dgl.bipartite_from_networkx(nx_graph, 'A', 'e', 'B', u_attrs=['some_attr']),
        dgl.bipartite_from_networkx(nx_graph, 'B', 'e', 'A', u_attrs=['some_attr']),
    ]
    for graph in converted:
        dgl.batch((graph, graph, graph))


if __name__ == '__main__':
# Manual entry point for running this module as a script.  Every call
# below is commented out, so running the script currently does nothing;
# uncomment a line to run that single test directly.
# NOTE(review): test_batched_features and test_to_device are not defined
# in the portion of the file visible here -- confirm they exist before
# re-enabling those calls.
# test_topology('int32')
# test_batching_batched('int32')
# test_batched_features('int32')
# test_empty_relation('int64')
# test_to_device('int32')
pass

0 comments on commit 19d3bce

Please sign in to comment.