Skip to content

Commit

Permalink
add bmm test back
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Nov 18, 2022
1 parent f170bc5 commit bac8f2b
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 36 deletions.
64 changes: 64 additions & 0 deletions tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,70 @@ def hyper_gnn(
)


@T.prim_func
def bmm(
x: T.handle,
y: T.handle,
z: T.handle,
indptr_i: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_ij: T.handle,
indptr_ik: T.handle,
indptr_jk: T.handle,
batch_size: T.int32,
nnz_i: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_ij: T.int32,
nnz_jk: T.int32,
nnz_ik: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
X_data = T.match_buffer(x, [nnz_ij], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [nnz_jk], dtype="float32", strides=[1])
Z_data = T.match_buffer(z, [nnz_ik], dtype="float32", strides=[1])
I_indptr_data = T.match_buffer(indptr_i, [batch_size + 1], dtype="int32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [batch_size + 1], dtype="int32", strides=[1])
K_indptr_data = T.match_buffer(indptr_k, [batch_size + 1], dtype="int32", strides=[1])
IJ_indptr_data = T.match_buffer(indptr_ij, [batch_size + 1], dtype="int32", strides=[1])
IK_indptr_data = T.match_buffer(indptr_ik, [batch_size + 1], dtype="int32", strides=[1])
JK_indptr_data = T.match_buffer(indptr_jk, [batch_size + 1], dtype="int32", strides=[1])
# body
# with T.block("root")
for b in T.serial(batch_size):
with T.block("bmm0"):
vb = T.axis.spatial(batch_size, b)
T.reads(
I_indptr_data[0 : batch_size + 1],
J_indptr_data[0 : batch_size + 1],
K_indptr_data[0 : batch_size + 1],
X_data[0:nnz_ij],
Y_data[0:nnz_jk],
)
T.writes(Z_data[0:nnz_ik])
T.block_attr({"sparse": True})
for i, j, k in T.grid(
I_indptr_data[vb + 1] - I_indptr_data[vb],
J_indptr_data[vb + 1] - J_indptr_data[vb],
K_indptr_data[vb + 1] - K_indptr_data[vb],
):
with T.block("bmm1"):
vi = T.axis.spatial(32768, i)
vj = T.axis.reduce(32768, j)
vk = T.axis.spatial(32768, k)
T.reads(X_data[vj + IJ_indptr_data[vb]], Y_data[vk + JK_indptr_data[vb]])
T.writes(Z_data[vk + IK_indptr_data[vb]])
T.block_attr({"sparse": True})
with T.init():
Z_data[vk + IK_indptr_data[vb]] = T.float32(0)
Z_data[vk + IK_indptr_data[vb]] = (
Z_data[vk + IK_indptr_data[vb]]
+ X_data[vj + IJ_indptr_data[vb]] * Y_data[vk + JK_indptr_data[vb]]
)


@T.prim_func
def sddmm(
a: T.handle,
Expand Down
74 changes: 74 additions & 0 deletions tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,80 @@ def hyper_gnn(
Y[vi, vf] = Y[vi, vf] + X[I_T_indices[vj, vi_t], vf]


@T.prim_func
def bmm(
x: T.handle,
y: T.handle,
z: T.handle,
indptr_i: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_ij: T.handle,
indptr_ik: T.handle,
indptr_jk: T.handle,
batch_size: T.int32,
nnz_i: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_ij: T.int32,
nnz_jk: T.int32,
nnz_ik: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1})
B = T.dense_fixed(batch_size, idtype="int32")
I = T.dense_variable(B, (32768, nnz_i), indptr_i, idtype="int32")
J = T.dense_variable(B, (32768, nnz_j), indptr_j, idtype="int32")
K = T.dense_variable(B, (32768, nnz_k), indptr_k, idtype="int32")
IJ = T.dense_variable(B, (32768, nnz_ij), indptr_ij, idtype="int32")
JK = T.dense_variable(B, (32768, nnz_jk), indptr_jk, idtype="int32")
IK = T.dense_variable(B, (32768, nnz_ik), indptr_ik, idtype="int32")
X = T.match_sparse_buffer(x, [B, I, IJ], dtype="float32")
Y = T.match_sparse_buffer(y, [B, J, JK], dtype="float32")
Z = T.match_sparse_buffer(z, [B, I, IK], dtype="float32")
I_indptr = T.match_sparse_buffer(indptr_i, [B], dtype="int32", extra_storage=1)
J_indptr = T.match_sparse_buffer(indptr_j, [B], dtype="int32", extra_storage=1)
K_indptr = T.match_sparse_buffer(indptr_k, [B], dtype="int32", extra_storage=1)
IJ_indptr = T.match_sparse_buffer(indptr_ij, [B], dtype="int32", extra_storage=1)
IK_indptr = T.match_sparse_buffer(indptr_ik, [B], dtype="int32", extra_storage=1)
JK_indptr = T.match_sparse_buffer(indptr_jk, [B], dtype="int32", extra_storage=1)
# body
# with T.block("root")
T.assume_buffer_domain(I_indptr, [0, nnz_i])
T.assume_buffer_domain(J_indptr, [0, nnz_j])
T.assume_buffer_domain(K_indptr, [0, nnz_k])
T.assume_buffer_domain(IJ_indptr, [0, nnz_ij])
T.assume_buffer_domain(JK_indptr, [0, nnz_jk])
T.assume_buffer_domain(IK_indptr, [0, nnz_ik])
for b in T.serial(batch_size):
with T.block("bmm0"):
vb = T.axis.spatial(batch_size, b)
T.reads(
I_indptr[vb : vb + 2],
J_indptr[vb : vb + 2],
K_indptr[vb : vb + 2],
X[vb, 0:32768, 0:32768],
Y[vb, 0:32768, 0:32768],
)
T.writes(Z[vb, 0:32768, 0:32768])
T.block_attr({"sparse": True})
for i, j, k in T.grid(
I_indptr[vb + 1] - I_indptr[vb],
J_indptr[vb + 1] - J_indptr[vb],
K_indptr[vb + 1] - K_indptr[vb],
):
with T.block("bmm1"):
vi = T.axis.spatial(32768, i)
vj = T.axis.reduce(32768, j)
vk = T.axis.spatial(32768, k)
T.reads(X[vb, vi, vj], Y[vb, vj, vk])
T.writes(Z[vb, vi, vk])
T.block_attr({"sparse": True})
with T.init():
Z[vb, vi, vk] = T.float32(0)
Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vj] * Y[vb, vj, vk]


@T.prim_func
def sddmm(
a: T.handle,
Expand Down
68 changes: 34 additions & 34 deletions tests/python/sparsetir/sparse_tir_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,40 +228,40 @@ def hyper_gnn(
Y[i, f] = Y[i, f] + X[i_t, f]


# @T.prim_func
# def bmm(
# x: T.handle,
# y: T.handle,
# z: T.handle,
# indptr_i: T.handle,
# indptr_j: T.handle,
# indptr_k: T.handle,
# offset_ij: T.handle,
# offset_jk: T.handle,
# offset_ik: T.handle,
# batch_size: T.int32,
# nnz_i: T.int32,
# nnz_j: T.int32,
# nnz_k: T.int32,
# nnz_ij: T.int32,
# nnz_jk: T.int32,
# nnz_ik: T.int32,
# ) -> None:
# T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
# B = T.dense_fixed(batch_size)
# I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32")
# J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32")
# K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32")
# (IJ,) = T.flatten([I, J], nnz_ij, (offset_ij,))
# (JK,) = T.flatten([J, K], nnz_jk, (offset_jk,))
# (IK,) = T.flatten([I, K], nnz_ik, (offset_ik,))
# X = T.match_sparse_buffer(x, (B, IJ, J), "float32")
# Y = T.match_sparse_buffer(y, (B, JK, K), "float32")
# Z = T.match_sparse_buffer(z, (B, IK, K), "float32")
# with T.iter([B, I, J, K], "SSRS", "bmm") as [vb, vi, vj, vk]:
# with T.init():
# Z[vb, vi, vk] = 0.0
# Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vj] * Y[vb, vj, vk]
@T.prim_func
def bmm(
x: T.handle,
y: T.handle,
z: T.handle,
indptr_i: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_ij: T.handle,
indptr_ik: T.handle,
indptr_jk: T.handle,
batch_size: T.int32,
nnz_i: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_ij: T.int32,
nnz_jk: T.int32,
nnz_ik: T.int32,
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
B = T.dense_fixed(batch_size)
I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32")
J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32")
K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32")
IJ = T.dense_variable(B, (32768, nnz_ij), indptr_ij, "int32")
JK = T.dense_variable(B, (32768, nnz_jk), indptr_jk, "int32")
IK = T.dense_variable(B, (32768, nnz_ik), indptr_ik, "int32")
X = T.match_sparse_buffer(x, (B, I, IJ), "float32")
Y = T.match_sparse_buffer(y, (B, J, JK), "float32")
Z = T.match_sparse_buffer(z, (B, I, IK), "float32")
with T.iter([B, I, J, K], "SSRS", "bmm") as [b, i, j, k]:
with T.init():
Z[b, i, k] = 0.0
Z[b, i, k] = Z[b, i, k] + X[b, i, j] * Y[b, j, k]


@T.prim_func
Expand Down
2 changes: 1 addition & 1 deletion tests/python/sparsetir/test_tir_sparse_lower_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"ellmm",
"csr_element_wise",
"hyper_gnn",
# "bmm",
"bmm",
"sddmm",
"fused_sddmm",
"square_sum",
Expand Down
2 changes: 1 addition & 1 deletion tests/python/sparsetir/test_tir_sparse_lower_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"ellmm",
"csr_element_wise",
"hyper_gnn",
# "bmm",
"bmm",
"sddmm",
"fused_sddmm",
"square_sum",
Expand Down

0 comments on commit bac8f2b

Please sign in to comment.