Skip to content

Commit

Permalink
fix lowering test
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Nov 13, 2022
1 parent 574eb90 commit 8ae7b13
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 212 deletions.
24 changes: 13 additions & 11 deletions tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,13 +770,14 @@ def fused_reduction_4d_3d(


@T.prim_func
def rgcn_forward(
def rgcn_homo_forward(
etype: T.handle,
w: T.handle,
x: T.handle,
y: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
r: T.int32,
feat_size: T.int32,
Expand All @@ -787,16 +788,16 @@ def rgcn_forward(
E_data = T.match_buffer(etype, [nnz], dtype="int32", strides=[1])
W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32", strides=[1])
X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
Y_data = T.match_buffer(y, [m * feat_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for i, fo in T.grid(n, feat_size):
with T.block("rgcn-forward0"):
for i, fo in T.grid(m, feat_size):
with T.block("rgcn-homo-forward0"):
vi, vfo = T.axis.remap("SS", [i, fo])
T.reads(
J_indptr_data[0 : n + 1],
J_indptr_data[0 : m + 1],
W_data[0 : r * feat_size * feat_size],
E_data[0:nnz],
X_data[0 : n * feat_size],
Expand All @@ -805,7 +806,7 @@ def rgcn_forward(
T.writes(Y_data[vi * feat_size + vfo])
T.block_attr({"sparse": True})
for j, fi in T.grid(J_indptr_data[vi + 1] - J_indptr_data[vi], feat_size):
with T.block("rgcn-forward1"):
with T.block("rgcn-homo-forward1"):
vj = T.axis.reduce(n, j)
vfi = T.axis.reduce(feat_size, fi)
T.reads(
Expand Down Expand Up @@ -843,6 +844,7 @@ def rgcn_hetero_forward(
indices_i: T.handle,
indptr_j: T.handle,
indices_j: T.handle,
m: T.int32,
n: T.int32,
num_rels: T.int32,
feat_size: T.int32,
Expand All @@ -854,7 +856,7 @@ def rgcn_hetero_forward(
A_data = T.match_buffer(a, [nnz_j], dtype="float32", strides=[1])
W_data = T.match_buffer(w, [num_rels * feat_size * feat_size], dtype="float32", strides=[1])
X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [m * feat_size], dtype="float32", strides=[1])
I_indptr_data = T.match_buffer(indptr_i, [num_rels + 1], dtype="int32", strides=[1])
I_indices_data = T.match_buffer(indices_i, [nnz_i], dtype="int32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [nnz_i + 1], dtype="int32", strides=[1])
Expand All @@ -873,11 +875,11 @@ def rgcn_hetero_forward(
X_data[0 : n * feat_size],
J_indices_data[0:nnz_j],
)
T.writes(Y_data[0 : n * feat_size])
T.writes(Y_data[0 : m * feat_size])
T.block_attr({"sparse": True})
for i in T.serial(I_indptr_data[vr + 1] - I_indptr_data[vr]):
with T.block("rgcn-hetero-forward1"):
vi = T.axis.spatial(n, i)
vi = T.axis.spatial(m, i)
T.reads(
J_indptr_data[0 : nnz_i + 1],
I_indices_data[vi + I_indptr_data[vr]],
Expand Down Expand Up @@ -954,7 +956,7 @@ def sparse_softmax(
T.writes(TMP_data[vi])
T.block_attr({"sparse": True})
with T.init():
TMP_data[vi] = T.float32(-100000)
TMP_data[vi] = T.min_value("float32")
TMP_data[vi] = T.max(TMP_data[vi], A_data[vj + J_indptr_data[vi]])
for j in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("exp_and_sum0"):
Expand Down
Loading

0 comments on commit 8ae7b13

Please sign in to comment.