diff --git a/tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py b/tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py index 2e1879261..8ce6c9489 100644 --- a/tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py +++ b/tests/python/sparsetir/sparse_tir_lowered_buffer_scripts.py @@ -770,13 +770,14 @@ def fused_reduction_4d_3d( @T.prim_func -def rgcn_forward( +def rgcn_homo_forward( etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, + m: T.int32, n: T.int32, r: T.int32, feat_size: T.int32, @@ -787,16 +788,16 @@ def rgcn_forward( E_data = T.match_buffer(etype, [nnz], dtype="int32", strides=[1]) W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32", strides=[1]) X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1]) - Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1]) - J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1]) + Y_data = T.match_buffer(y, [m * feat_size], dtype="float32", strides=[1]) + J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1]) J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1]) # body # with T.block("root") - for i, fo in T.grid(n, feat_size): - with T.block("rgcn-forward0"): + for i, fo in T.grid(m, feat_size): + with T.block("rgcn-homo-forward0"): vi, vfo = T.axis.remap("SS", [i, fo]) T.reads( - J_indptr_data[0 : n + 1], + J_indptr_data[0 : m + 1], W_data[0 : r * feat_size * feat_size], E_data[0:nnz], X_data[0 : n * feat_size], @@ -805,7 +806,7 @@ def rgcn_forward( T.writes(Y_data[vi * feat_size + vfo]) T.block_attr({"sparse": True}) for j, fi in T.grid(J_indptr_data[vi + 1] - J_indptr_data[vi], feat_size): - with T.block("rgcn-forward1"): + with T.block("rgcn-homo-forward1"): vj = T.axis.reduce(n, j) vfi = T.axis.reduce(feat_size, fi) T.reads( @@ -843,6 +844,7 @@ def rgcn_hetero_forward( indices_i: T.handle, indptr_j: T.handle, indices_j: T.handle, + m: T.int32, n: T.int32, num_rels: T.int32, feat_size: T.int32, @@ -854,7 +856,7 @@ def rgcn_hetero_forward( A_data = T.match_buffer(a, [nnz_j], dtype="float32", strides=[1]) W_data = T.match_buffer(w, [num_rels * feat_size * feat_size], dtype="float32", strides=[1]) X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1]) - Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1]) + Y_data = T.match_buffer(y, [m * feat_size], dtype="float32", strides=[1]) I_indptr_data = T.match_buffer(indptr_i, [num_rels + 1], dtype="int32", strides=[1]) I_indices_data = T.match_buffer(indices_i, [nnz_i], dtype="int32", strides=[1]) J_indptr_data = T.match_buffer(indptr_j, [nnz_i + 1], dtype="int32", strides=[1]) @@ -873,11 +875,11 @@ def rgcn_hetero_forward( X_data[0 : n * feat_size], J_indices_data[0:nnz_j], ) - T.writes(Y_data[0 : n * feat_size]) + T.writes(Y_data[0 : m * feat_size]) T.block_attr({"sparse": True}) for i in T.serial(I_indptr_data[vr + 1] - I_indptr_data[vr]): with T.block("rgcn-hetero-forward1"): - vi = T.axis.spatial(n, i) + vi = T.axis.spatial(m, i) T.reads( J_indptr_data[0 : nnz_i + 1], I_indices_data[vi + I_indptr_data[vr]], @@ -954,7 +956,7 @@ def sparse_softmax( T.writes(TMP_data[vi]) T.block_attr({"sparse": True}) with T.init(): - TMP_data[vi] = T.float32(-100000) + TMP_data[vi] = T.min_value("float32") TMP_data[vi] = T.max(TMP_data[vi], A_data[vj + J_indptr_data[vi]]) for j in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]): with T.block("exp_and_sum0"): diff --git a/tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py b/tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py index 8ef895516..7e3998889 100644 --- a/tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py +++ b/tests/python/sparsetir/sparse_tir_lowered_iter_scripts.py @@ -32,11 +32,11 @@ def csrmm( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") - J_detach = T.dense_fixed(n, "int32") - K = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(n, idtype="int32") + K = T.dense_fixed(feat_size, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [J_detach, K], dtype="float32") C = T.match_sparse_buffer(c, [I, K], dtype="float32") @@ -78,11 +78,11 @@ def csrmm_dense_iter( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") - J_detach = T.dense_fixed(n, "int32") - K = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(n, idtype="int32") + K = T.dense_fixed(feat_size, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [J_detach, K], dtype="float32") C = T.match_sparse_buffer(c, [I, K], dtype="float32") @@ -126,8 +126,8 @@ def csrmm_dense_iter( def segment_reduce(a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: T.int32) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.dense_variable(I, (100, nnz), indptr, "int32") + I = T.dense_fixed(n, idtype="int32") + J = T.dense_variable(I, (100, nnz), indptr, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [I], dtype="float32") J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) @@ -151,6 +151,46 @@ def segment_reduce(a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: B[vi] = B[vi] + A[vi, vj] +@T.prim_func +def csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) + I = T.dense_fixed(n, idtype="int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (m, nnz), indptr, idtype="int32") + A = T.match_sparse_buffer(a, [I, J], dtype="float32") + B = T.match_sparse_buffer(b, [I], dtype="float32") + J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) + J_indices = T.match_sparse_buffer(indices, [I, J_dense], dtype="int32") + # body + # with T.block("root") + T.assume_buffer_domain(J_indptr, [0, nnz]) + T.assume_buffer_domain(J_indices, [0, m]) + for i in T.serial(n): + with T.block("csr_reduce0"): + vi = T.axis.spatial(n, i) + T.reads(J_indptr[vi : vi + 2], A[vi, 0:m]) + T.writes(B[vi]) + T.block_attr({"sparse": True}) + for j in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csr_reduce1"): + vj = T.axis.reduce(m, j) + T.reads(A[vi, vj]) + T.writes(B[vi]) + T.block_attr({"sparse": True}) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vi, vj] + + @T.prim_func def bsrmm( a: T.handle, @@ -166,13 +206,13 @@ def bsrmm( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(nb, "int32") - J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (mb, nnzb), indptr, "int32") - J_detach = T.dense_fixed(mb, "int32") - BI = T.dense_fixed(blk, "int32") - BJ = T.dense_fixed(blk, "int32") - F = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(nb, idtype="int32") + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (mb, nnzb), indptr, idtype="int32") + J_detach = T.dense_fixed(mb, idtype="int32") + BI = T.dense_fixed(blk, idtype="int32") + BJ = T.dense_fixed(blk, idtype="int32") + F = T.dense_fixed(feat_size, idtype="int32") A = T.match_sparse_buffer(a, [I, J, BI, BJ], dtype="float32") B = T.match_sparse_buffer(b, [J_detach, BJ, F], dtype="float32") C = T.match_sparse_buffer(c, [I, BI, F], dtype="float32") @@ -204,43 +244,42 @@ def bsrmm( @T.prim_func -def csr_reduce( +def ellmm( a: T.handle, b: T.handle, - indptr: T.handle, + c: T.handle, indices: T.handle, - n: T.int32, - m: T.int32, - nnz: T.int32, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + col: T.int32, + blk: T.int32, ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (m, nnz), indptr, "int32") - A = T.match_sparse_buffer(a, [I, J], dtype="float32") - B = T.match_sparse_buffer(b, [I], dtype="float32") - J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) + I = T.dense_fixed(nb, idtype="int32") + J = T.sparse_fixed(I, (mb, col), indices, idtype="int32", sorted=True) + J_dense = T.dense_fixed(col, idtype="int32") + J_detach = T.dense_fixed(mb, idtype="int32") + F = T.dense_fixed(feat_size, idtype="int32") + BI = T.dense_fixed(blk, idtype="int32") + BJ = T.dense_fixed(blk, idtype="int32") + A = T.match_sparse_buffer(a, [I, J, BI, BJ], dtype="float32") + B = T.match_sparse_buffer(b, [J_detach, BJ, F], dtype="float32") + C = T.match_sparse_buffer(c, [I, BI, F], dtype="float32") J_indices = T.match_sparse_buffer(indices, [I, J_dense], dtype="int32") # body # with T.block("root") - T.assume_buffer_domain(J_indptr, [0, nnz]) - T.assume_buffer_domain(J_indices, [0, m]) - for i in T.serial(n): - with T.block("csr_reduce0"): - vi = T.axis.spatial(n, i) - T.reads(J_indptr[vi : vi + 2], A[vi, 0:m]) - T.writes(B[vi]) + T.assume_buffer_domain(J_indices, [0, mb]) + for i, j, bi, bj, f in T.grid(nb, col, blk, blk, feat_size): + with T.block("ellmm0"): + vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [i, j, bi, bj, f]) + T.reads(A[vi, vj, vbi, vbj], B[J_indices[vi, vj], vbj, vf], J_indices[vi, vj]) + T.writes(C[vi, vbi, vf]) T.block_attr({"sparse": True}) - for j in T.serial(J_indptr[vi + 1] - J_indptr[vi]): - with T.block("csr_reduce1"): - vj = T.axis.reduce(m, j) - T.reads(A[vi, vj]) - T.writes(B[vi]) - T.block_attr({"sparse": True}) - with T.init(): - B[vi] = T.float32(0) - B[vi] = B[vi] + A[vi, vj] + with T.init(): + C[vi, vbi, vf] = T.float32(0) + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[J_indices[vi, vj], vbj, vf] @T.prim_func @@ -255,9 +294,9 @@ def csr_element_wise( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [I, J], dtype="float32") J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) @@ -296,13 +335,13 @@ def hyper_gnn( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - F = T.dense_fixed(feat_size, "int32") - J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (m, nnz), indptr, "int32") - J_detach = T.dense_fixed(m, "int32") - I_T = T.sparse_variable(J_detach, (n, nnz), (indptr_T, indices_T), "int32") - I_T_dense = T.dense_variable(J_detach, (n, nnz), indptr_T, "int32") + I = T.dense_fixed(n, idtype="int32") + F = T.dense_fixed(feat_size, idtype="int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (m, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(m, idtype="int32") + I_T = T.sparse_variable(J_detach, (n, nnz), (indptr_T, indices_T), idtype="int32", sorted=True) + I_T_dense = T.dense_variable(J_detach, (n, nnz), indptr_T, idtype="int32") X = T.match_sparse_buffer(x, [I, F], dtype="float32") Y = T.match_sparse_buffer(y, [I, F], dtype="float32") J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) @@ -351,45 +390,6 @@ def hyper_gnn( Y[vi, vf] = Y[vi, vf] + X[I_T_indices[vj, vi_t], vf] -@T.prim_func -def ellmm( - a: T.handle, - b: T.handle, - c: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - feat_size: T.int32, - col: T.int32, - blk: T.int32, -) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(nb, "int32") - J = T.sparse_fixed(I, (mb, col), indices, "int32") - J_dense = T.dense_fixed(col, "int32") - J_detach = T.dense_fixed(mb, "int32") - F = T.dense_fixed(feat_size, "int32") - BI = T.dense_fixed(blk, "int32") - BJ = T.dense_fixed(blk, "int32") - A = T.match_sparse_buffer(a, [I, J, BI, BJ], dtype="float32") - B = T.match_sparse_buffer(b, [J_detach, BJ, F], dtype="float32") - C = T.match_sparse_buffer(c, [I, BI, F], dtype="float32") - J_indices = T.match_sparse_buffer(indices, [I, J_dense], dtype="int32") - # body - # with T.block("root") - T.assume_buffer_domain(J_indices, [0, mb]) - for i, j, bi, bj, f in T.grid(nb, col, blk, blk, feat_size): - with T.block("ellmm0"): - vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [i, j, bi, bj, f]) - T.reads(A[vi, vj, vbi, vbj], B[J_indices[vi, vj], vbj, vf], J_indices[vi, vj]) - T.writes(C[vi, vbi, vf]) - T.block_attr({"sparse": True}) - with T.init(): - C[vi, vbi, vf] = T.float32(0) - C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[J_indices[vi, vj], vbj, vf] - - @T.prim_func def sddmm( a: T.handle, @@ -404,11 +404,11 @@ def sddmm( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") - J_detach = T.dense_fixed(n, "int32") - K = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(n, idtype="int32") + K = T.dense_fixed(feat_size, idtype="int32") A = T.match_sparse_buffer(a, [I, K], dtype="float32") B = T.match_sparse_buffer(b, [J_detach, K], dtype="float32") C = T.match_sparse_buffer(c, [I, J], dtype="float32") @@ -452,11 +452,11 @@ def fused_sddmm( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") - J_detach = T.dense_fixed(n, "int32") - K = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(n, idtype="int32") + K = T.dense_fixed(feat_size, idtype="int32") A = T.match_sparse_buffer(a, [I, K], dtype="float32") B = T.match_sparse_buffer(b, [J_detach, K], dtype="float32") C = T.match_sparse_buffer(c, [I, J], dtype="float32") @@ -517,11 +517,11 @@ def square_sum( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(M, "int32") - J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") - J_dense = T.dense_variable(I, (N1, nnz_j), indptr_j, "int32") - K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32") - K_dense = T.dense_variable(J, (N2, nnz_k), indptr_k, "int32") + I = T.dense_fixed(M, idtype="int32") + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (N1, nnz_j), indptr_j, idtype="int32") + K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), idtype="int32", sorted=True) + K_dense = T.dense_variable(J, (N2, nnz_k), indptr_k, idtype="int32") A = T.match_sparse_buffer(a, [I, J, K], dtype="float32") B = T.match_sparse_buffer(b, [I], dtype="float32") J_indptr = T.match_sparse_buffer(indptr_j, [I], dtype="int32", extra_storage=1) @@ -575,13 +575,13 @@ def square_sum_two_K( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(M, "int32") - J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") - J_dense = T.dense_variable(I, (N1, nnz_j), indptr_j, "int32") - K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32") - K0_dense = T.dense_variable(J, (N2, nnz_k), indptr_k0, "int32") - K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32") - K1_dense = T.dense_variable(J, (N2, nnz_k), indptr_k1, "int32") + I = T.dense_fixed(M, idtype="int32") + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (N1, nnz_j), indptr_j, idtype="int32") + K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), idtype="int32", sorted=True) + K0_dense = T.dense_variable(J, (N2, nnz_k), indptr_k0, idtype="int32") + K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), idtype="int32", sorted=True) + K1_dense = T.dense_variable(J, (N2, nnz_k), indptr_k1, idtype="int32") A = T.match_sparse_buffer(a, [I, J, K0], dtype="float32") B = T.match_sparse_buffer(b, [I], dtype="float32") J_indptr = T.match_sparse_buffer(indptr_j, [I], dtype="int32", extra_storage=1) @@ -676,10 +676,10 @@ def fused_reduction_4d_2d( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") - K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") - L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + I = T.dense_fixed(n, idtype="int32") + J = T.dense_variable(I, (32768, nnz_j), indptr_j, idtype="int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, idtype="int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, idtype="int32") X = T.match_sparse_buffer(x, [I, J, K, L], dtype="float32") Y = T.match_sparse_buffer(y, [I, J], dtype="float32") J_indptr = T.match_sparse_buffer(indptr_j, [I], dtype="int32", extra_storage=1) @@ -730,10 +730,10 @@ def fused_reduction_4d_3d( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") - K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") - L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + I = T.dense_fixed(n, idtype="int32") + J = T.dense_variable(I, (32768, nnz_j), indptr_j, idtype="int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, idtype="int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, idtype="int32") X = T.match_sparse_buffer(x, [I, J, K, L], dtype="float32") Y = T.match_sparse_buffer(y, [I, J, K], dtype="float32") J_indptr = T.match_sparse_buffer(indptr_j, [I], dtype="int32", extra_storage=1) @@ -764,13 +764,14 @@ def fused_reduction_4d_3d( @T.prim_func -def rgcn_forward( +def rgcn_homo_forward( etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, + m: T.int32, n: T.int32, r: T.int32, feat_size: T.int32, @@ -778,13 +779,13 @@ def rgcn_forward( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") - J_detach = T.dense_fixed(n, "int32") - R = T.dense_fixed(r, "int32") - F_in = T.dense_fixed(feat_size, "int32") - F_out = T.dense_fixed(feat_size, "int32") + I = T.dense_fixed(m, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") + J_detach = T.dense_fixed(n, idtype="int32") + R = T.dense_fixed(r, idtype="int32") + F_in = T.dense_fixed(feat_size, idtype="int32") + F_out = T.dense_fixed(feat_size, idtype="int32") E = T.match_sparse_buffer(etype, [I, J], dtype="int32") W = T.match_sparse_buffer(w, [R, F_out, F_in], dtype="float32") X = T.match_sparse_buffer(x, [J_detach, F_in], dtype="float32") @@ -795,8 +796,8 @@ def rgcn_forward( # with T.block("root") T.assume_buffer_domain(J_indptr, [0, nnz]) T.assume_buffer_domain(J_indices, [0, n]) - for i, fo in T.grid(n, feat_size): - with T.block("rgcn-forward0"): + for i, fo in T.grid(m, feat_size): + with T.block("rgcn-homo-forward0"): vi, vfo = T.axis.remap("SS", [i, fo]) T.reads( J_indptr[vi : vi + 2], @@ -808,7 +809,7 @@ def rgcn_forward( T.writes(Y[vi, vfo]) T.block_attr({"sparse": True}) for j, fi in T.grid(J_indptr[vi + 1] - J_indptr[vi], feat_size): - with T.block("rgcn-forward1"): + with T.block("rgcn-homo-forward1"): vj = T.axis.reduce(n, j) vfi = T.axis.reduce(feat_size, fi) T.reads( @@ -834,6 +835,7 @@ def rgcn_hetero_forward( indices_i: T.handle, indptr_j: T.handle, indices_j: T.handle, + m: T.int32, n: T.int32, num_rels: T.int32, feat_size: T.int32, @@ -843,11 +845,11 @@ def rgcn_hetero_forward( # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) R = T.dense_fixed(num_rels, idtype="int32") - I = T.sparse_variable(R, (n, nnz_i), (indptr_i, indices_i), idtype="int32") - I_dense = T.dense_variable(R, (n, nnz_i), indptr_i, idtype="int32") - J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), idtype="int32") + I = T.sparse_variable(R, (m, nnz_i), (indptr_i, indices_i), idtype="int32", sorted=True) + I_dense = T.dense_variable(R, (m, nnz_i), indptr_i, idtype="int32") + J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), idtype="int32", sorted=True) J_dense = T.dense_variable(I, (n, nnz_j), indptr_j, idtype="int32") - I_detach = T.dense_fixed(n, idtype="int32") + I_detach = T.dense_fixed(m, idtype="int32") J_detach = T.dense_fixed(n, idtype="int32") F_in = T.dense_fixed(feat_size, idtype="int32") F_out = T.dense_fixed(feat_size, idtype="int32") @@ -862,7 +864,7 @@ def rgcn_hetero_forward( # body # with T.block("root") T.assume_buffer_domain(I_indptr, [0, nnz_i]) - T.assume_buffer_domain(I_indices, [0, n]) + T.assume_buffer_domain(I_indices, [0, m]) T.assume_buffer_domain(J_indptr, [0, nnz_j]) T.assume_buffer_domain(J_indices, [0, n]) for fo, r in T.grid(feat_size, num_rels): @@ -870,18 +872,18 @@ def rgcn_hetero_forward( vfo, vr = T.axis.remap("SS", [fo, r]) T.reads( I_indptr[vr : vr + 2], - J_indptr[vr, 0 : n + 1], - I_indices[vr, 0:n], - A[vr, 0:n, 0:n], + J_indptr[vr, 0 : m + 1], + I_indices[vr, 0:m], + A[vr, 0:m, 0:n], W[vr, vfo, 0:feat_size], X[0:n, 0:feat_size], - J_indices[vr, 0:n, 0:n], + J_indices[vr, 0:m, 0:n], ) - T.writes(Y[0:n, vfo]) + T.writes(Y[0:m, vfo]) T.block_attr({"sparse": True}) for i in T.serial(I_indptr[vr + 1] - I_indptr[vr]): with T.block("rgcn-hetero-forward1"): - vi = T.axis.spatial(n, i) + vi = T.axis.spatial(m, i) T.reads( J_indptr[vr, vi : vi + 2], I_indices[vr, vi], @@ -919,9 +921,9 @@ def sparse_softmax( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(n, "int32") - J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") - J_dense = T.dense_variable(I, (n, nnz), indptr, "int32") + I = T.dense_fixed(n, idtype="int32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [I, J], dtype="float32") J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1) @@ -945,7 +947,7 @@ def sparse_softmax( T.writes(TMP[vi]) T.block_attr({"sparse": True}) with T.init(): - TMP[vi] = T.float32(-100000) + TMP[vi] = T.min_value("float32") TMP[vi] = T.max(TMP[vi], A[vi, vj]) for j in T.serial(J_indptr[vi + 1] - J_indptr[vi]): with T.block("exp_and_sum0"): @@ -983,14 +985,16 @@ def csr2bsr( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1}) - I = T.dense_fixed(m_in, "int32") - J = T.sparse_variable(I, (n_in, nnz_in), (indptr_in, indices_in), "int32") - J_dense = T.dense_variable(I, (n_in, nnz_in), indptr_in, "int32") - I_bsr = T.dense_fixed(m_out, "int32") - J_bsr = T.sparse_variable(I_bsr, (n_out, nnz_out), (indptr_out, indices_out), "int32") - J_bsr_dense = T.dense_variable(I_bsr, (n_out, nnz_out), indptr_out, "int32") - BI = T.dense_fixed(blk_size, "int32") - BJ = T.dense_fixed(blk_size, "int32") + I = T.dense_fixed(m_in, idtype="int32") + J = T.sparse_variable(I, (n_in, nnz_in), (indptr_in, indices_in), idtype="int32", sorted=True) + J_dense = T.dense_variable(I, (n_in, nnz_in), indptr_in, idtype="int32") + I_bsr = T.dense_fixed(m_out, idtype="int32") + J_bsr = T.sparse_variable( + I_bsr, (n_out, nnz_out), (indptr_out, indices_out), idtype="int32", sorted=True + ) + J_bsr_dense = T.dense_variable(I_bsr, (n_out, nnz_out), indptr_out, idtype="int32") + BI = T.dense_fixed(blk_size, idtype="int32") + BJ = T.dense_fixed(blk_size, idtype="int32") A = T.match_sparse_buffer(a, [I, J], dtype="float32") B = T.match_sparse_buffer(b, [I_bsr, J_bsr, BI, BJ], dtype="float32") J_indptr = T.match_sparse_buffer(indptr_in, [I], dtype="int32", extra_storage=1) diff --git a/tests/python/sparsetir/sparse_tir_scripts.py b/tests/python/sparsetir/sparse_tir_scripts.py index 40e5c12d7..cb2bde2f9 100644 --- a/tests/python/sparsetir/sparse_tir_scripts.py +++ b/tests/python/sparsetir/sparse_tir_scripts.py @@ -429,20 +429,21 @@ def fused_reduction_4d_3d( @T.prim_func -def rgcn_forward( +def rgcn_homo_forward( etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, + m: T.int32, n: T.int32, r: T.int32, feat_size: T.int32, nnz: T.int32, ): T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2}) - I = T.dense_fixed(n) + I = T.dense_fixed(m) J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") J_detach = T.dense_fixed(n) R = T.dense_fixed(r) @@ -452,7 +453,7 @@ def rgcn_forward( W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") X = T.match_sparse_buffer(x, (J_detach, F_in), "float32") Y = T.match_sparse_buffer(y, (I, F_out), "float32") - with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [ + with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-homo-forward") as [ i, fo, j, @@ -498,42 +499,6 @@ def rgcn_hetero_forward( Y[i, fo] = Y[i, fo] + A[r, i, j] * W[r, fo, fi] * X[j, fi] -@T.prim_func -def rgcn_hetero_forward_2( - w: T.handle, - x: T.handle, - y: T.handle, - etypes: T.handle, - indptr_i: T.handle, - indices_i: T.handle, - indptr_j: T.handle, - indices_j: T.handle, - n: T.int32, - num_rels: T.int32, - group: T.int32, - feat_size: T.int32, - nnz_i: T.int32, - nnz_j: T.int32, -): - T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2}) - R = T.dense_fixed(num_rels) - G = T.dense_fixed(group) - I = T.sparse_variable(G, (n, nnz_i), (indptr_i, indices_i), "int32") - J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), "int32") - I_detach = T.dense_fixed(n) - J_detach = T.dense_fixed(n) - F_in = T.dense_fixed(feat_size) - F_out = T.dense_fixed(feat_size) - W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") - X = T.match_sparse_buffer(x, (J_detach, F_in), "float32") - Y = T.match_sparse_buffer(y, (I_detach, F_out), "float32") - E = T.match_sparse_buffer(etypes, (G,), "int32") - with T.iter([F_out, G, I, J, F_in], "SSSRR", "rgcn-hetero-forward") as [fo, g, i, j, fi]: - with T.init(): - Y[i, fo] = 0. - Y[i, fo] = Y[i, fo] + W[E[g], fo, fi] * X[j, fi] - - @T.prim_func def sparse_softmax( a: T.handle, @@ -553,7 +518,7 @@ def sparse_softmax( with T.iter([I], "S", "sparse_softmax") as [i]: with T.iter([J], "R", "computer_max") as [j]: with T.init(): - TMP[i] = T.float32(-100000) + TMP[i] = T.min_value("float32") TMP[i] = T.max(TMP[i], A[i, j]) with T.iter([J], "R", "exp_and_sum") as [j]: with T.init(): diff --git a/tests/python/sparsetir/test_tir_sparse_lower_buffer.py b/tests/python/sparsetir/test_tir_sparse_lower_buffer.py index 4b2dd4f3e..6ee2d9630 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower_buffer.py +++ b/tests/python/sparsetir/test_tir_sparse_lower_buffer.py @@ -39,7 +39,7 @@ "square_sum_two_K", "fused_reduction_4d_2d", "fused_reduction_4d_3d", - "rgcn_forward", + "rgcn_homo_forward", "rgcn_hetero_forward", "sparse_softmax", "csr2bsr", diff --git a/tests/python/sparsetir/test_tir_sparse_lower_iter.py b/tests/python/sparsetir/test_tir_sparse_lower_iter.py index 711c85ac8..ee81cca60 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower_iter.py +++ b/tests/python/sparsetir/test_tir_sparse_lower_iter.py @@ -38,7 +38,7 @@ "square_sum_two_K", "fused_reduction_4d_2d", "fused_reduction_4d_3d", - "rgcn_forward", + "rgcn_homo_forward", "rgcn_hetero_forward", "sparse_softmax", "csr2bsr",