diff --git a/tests/python/sparsetir/lowered_tir.py b/tests/python/sparsetir/lowered_tir.py deleted file mode 100644 index 38b7807da..000000000 --- a/tests/python/sparsetir/lowered_tir.py +++ /dev/null @@ -1,825 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Lowered TIR scripts of sparse workloads.""" -from tvm.script import tir as T - - -@T.prim_func -def lowered_csrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (nnz,), "float32") - B_data = T.match_buffer(b, (n * k,), "float32") - C_data = T.match_buffer(c, (m * k,), "float32") - J_indptr = T.match_buffer(indptr, (m + 1,), "int32") - J_indices = T.match_buffer(indices, (nnz,), "int32") - # body - # with T.block("root") - for v_vi, v_vk in T.grid(m, k): - with T.block("csrmm0"): - vi, vk = T.axis.remap("SS", [v_vi, v_vk]) - T.reads( - J_indptr[0 : m + 1], - J_indices[0:nnz], - A_data[0:nnz], - B_data[0 : n * k], - C_data[0 : m * k], - ) - T.writes(C_data[0 : m * k]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): - with T.block("csrmm1"): - vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads( - J_indptr[0 : m + 1], - J_indices[0:nnz], - A_data[0:nnz], - B_data[0 : n * k], - C_data[0 : m * k], - ) - T.writes(C_data[0 : m * k]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[vi * k + vk] = T.float32(0) - C_data[vi * k + vk] = ( - C_data[vi * k + vk] - + A_data[J_indptr[vi] + vj] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] - ) - - -# @T.prim_func -# def lowered_csrmm_dense_iter( -# a: T.handle, -# b: T.handle, -# c: T.handle, -# indptr: T.handle, -# indices: T.handle, -# m: T.int32, -# n: T.int32, -# k: T.int32, -# nnz: T.int32, -# ) -> None: -# # function attr dict -# T.func_attr({"global_symbol": "main", "tir.noalias": True}) -# A_data = T.match_buffer(a, (nnz,), "float32") -# B_data = T.match_buffer(b, (n * k,), "float32") -# C_data = T.match_buffer(c, (m * k,), "float32") -# J_indptr = T.match_buffer(indptr, (m + 1,), "int32") -# J_indices = T.match_buffer(indices, (nnz,), "int32") -# # body -# # with T.block("root") -# for v_vi, v_vj, v_vk in T.grid(m, n, k): -# with T.block("csrmm0"): -# vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) -# T.reads( -# J_indptr[0 : m + 1], -# J_indices[0:nnz], -# A_data[0:nnz], -# B_data[0 : n * k], -# C_data[0 : m * k], -# ) -# T.writes(C_data[0 : m * k]) -# T.block_attr({"sparse": True}) -# with T.init(): -# C_data[vi * k + vk] = T.float32(0) -# C_data[vi * k + vk] = ( -# C_data[vi * k + vk] -# + A_data[ -# T.tvm_lower_bound( -# J_indices.data, vj, 
J_indptr[vi], J_indptr[vi + 1], dtype="int32" -# ) -# ] -# * B_data[vj * k + vk] -# ) - - -@T.prim_func -def lowered_csr_reduce( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - m: T.int32, - nnz: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, [nnz], dtype="float32") - B_data = T.match_buffer(b, [n], dtype="float32") - J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi in T.serial(0, n): - with T.block("csr_reduce_outer"): - vi = T.axis.spatial(n, v_vi) - T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]]) - T.writes([B_data[0:n]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): - with T.block("csr_reduce"): - vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]]) - T.writes([B_data[0:n]]) - T.block_attr({"sparse": True}) - with T.init(): - B_data[vi] = T.float32(0) - B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] - - -@T.prim_func -def lowered_segment_reduce( - a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: T.int32 -) -> None: - A_data = T.match_buffer(a, (nnz,), "float32") - B_data = T.match_buffer(b, (n,), "float32") - J_indptr = T.match_buffer(indptr, (n + 1,), "int32") - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for v_vi in T.serial(n): - with T.block("segment_reduce0"): - vi = T.axis.spatial(n, v_vi) - T.reads(J_indptr[0 : n + 1], A_data[0:nnz], B_data[0:n]) - T.writes(B_data[0:n]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): - with T.block("segment_reduce1"): - vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads(J_indptr[0 : n + 1], A_data[0:nnz], B_data[0:n]) - T.writes(B_data[0:n]) - T.block_attr({"sparse": True}) - with T.init(): - B_data[vi] = T.float32(0) - B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] - - -@T.prim_func -def lowered_bsrmm( - a: T.handle, - b: T.handle, - c: T.handle, - j_indptr: T.handle, - j_indices: T.handle, - nb: T.int32, - mb: T.int32, - nnzb: T.int32, - blk: T.int32, - feat_size: T.int32, -) -> None: - A_data = T.match_buffer(a, (nnzb * blk * blk,), "float32") - B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") - C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") - J_indptr = T.match_buffer(j_indptr, (nb + 1,), "int32") - J_indices = T.match_buffer(j_indices, (nnzb,), "int32") - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size): - with T.block("bsrmm0"): - vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf]) - T.reads( - J_indptr[0 : nb + 1], - J_indices[0:nnzb], - A_data[0 : nnzb * blk * blk], - B_data[0 : mb * blk * feat_size], - C_data[0 : nb * blk * feat_size], - ) - T.writes(C_data[0 : nb * blk * feat_size]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): - with T.block("bsrmm1"): - vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads( - J_indptr[0 : nb + 1], - J_indices[0:nnzb], - A_data[0 : nnzb * blk * blk], - B_data[0 : mb * blk * feat_size], - C_data[0 : nb * 
blk * feat_size], - ) - T.writes(C_data[0 : nb * blk * feat_size]) - T.block_attr({"sparse": True}) - C_data[(vi * blk + vbi) * feat_size + vf] = ( - C_data[(vi * blk + vbi) * feat_size + vf] - + A_data[((J_indptr[vi] + vj) * blk + vbi) * blk + vbj] - * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] - ) - - -@T.prim_func -def lowered_ellmm( - a: T.handle, - b: T.handle, - c: T.handle, - j_indices: T.handle, - nb: T.int32, - mb: T.int32, - feat_size: T.int32, - col: T.int32, - blk: T.int32, -) -> None: - A_data = T.match_buffer(a, (nb * col * blk * blk,), "float32") - B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") - C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") - J_indices = T.match_buffer(j_indices, (nb * col,), "int32") - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): - with T.block("ellmm0"): - vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) - T.reads( - J_indices[0 : nb * col], - A_data[0 : nb * col * blk * blk], - B_data[0 : mb * blk * feat_size], - C_data[0 : nb * blk * feat_size], - ) - T.writes(C_data[0 : nb * blk * feat_size]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - C_data[(vi * blk + vbi) * feat_size + vf] = ( - C_data[(vi * blk + vbi) * feat_size + vf] - + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] - * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] - ) - - -@T.prim_func -def lowered_sddmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (m * k,), "float32") - B_data = T.match_buffer(b, (n * k,), "float32") - C_data = T.match_buffer(c, (nnz,), "float32") - J_indptr = T.match_buffer(indptr, (m + 1,), "int32") - J_indices = T.match_buffer(indices, (nnz,), "int32") - for v_vi in T.serial(m): - with T.block("sddmm0"): - vi = T.axis.spatial(m, v_vi) - T.reads( - J_indptr[0 : m + 1], - J_indices[0:nnz], - A_data[0 : m * k], - B_data[0 : n * k], - C_data[0:nnz], - ) - T.writes(C_data[0:nnz]) - T.block_attr({"sparse": True}) - for v_vj, v_vk in T.grid(J_indptr[vi + 1] - J_indptr[vi], k): - with T.block("sddmm1"): - vj, vk = T.axis.remap("SR", [v_vj, v_vk]) - T.reads( - J_indptr[0 : m + 1], - J_indices[0:nnz], - A_data[0 : m * k], - B_data[0 : n * k], - C_data[0:nnz], - ) - T.writes(C_data[0:nnz]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[J_indptr[vi] + vj] = T.float32(0) - C_data[J_indptr[vi] + vj] = ( - C_data[J_indptr[vi] + vj] - + A_data[vi * k + vk] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] - ) - - -# from tvm.script import tir as T -@T.prim_func -def lowered_sddmm_fuse( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - k: T.int32, - nnz: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, (m * k,), "float32") - B_data = T.match_buffer(b, (n * k,), "float32") - C_data = T.match_buffer(c, (nnz,), "float32") - J_indptr = T.match_buffer(indptr, (m + 1,), "int32") - J_indices = T.match_buffer(indices, (nnz,), "int32") - # body - # with T.block("root") - for v_vi, v_vj, v_vk in T.grid(1, nnz, k): - with T.block("sddmm0"): - 
vi, vj, vk = T.axis.remap("SSR", [v_vi, v_vj, v_vk]) - T.reads( - J_indptr[0 : m + 1], - J_indices[0:nnz], - A_data[0 : m * k], - B_data[0 : n * k], - C_data[0:nnz], - ) - T.writes(C_data[0:nnz]) - T.block_attr({"sparse": True}) - with T.init(): - C_data[vj] = T.float32(0) - C_data[vj] = ( - C_data[vj] - + A_data[ - (T.tvm_upper_bound(J_indptr.data, vj, 0, m + 1, dtype="int32") - 1) * k + vk - ] - * B_data[J_indices[vj] * k + vk] - ) - - -@T.prim_func -def lowered_bmm( - x: T.handle, - y: T.handle, - z: T.handle, - indptr_i: T.handle, - indptr_j: T.handle, - indptr_k: T.handle, - indptr_ij: T.handle, - indptr_jk: T.handle, - indptr_ik: T.handle, - batch_size: T.int32, - nnz_i: T.int32, - nnz_j: T.int32, - nnz_k: T.int32, - nnz_ij: T.int32, - nnz_jk: T.int32, - nnz_ik: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - X_data = T.match_buffer(x, (nnz_ij,), "float32") - Y_data = T.match_buffer(y, (nnz_jk,), "float32") - Z_data = T.match_buffer(z, (nnz_ik,), "float32") - I_indptr = T.match_buffer(indptr_i, (batch_size + 1,), "int32") - J_indptr = T.match_buffer(indptr_j, (batch_size + 1,), "int32") - K_indptr = T.match_buffer(indptr_k, (batch_size + 1,), "int32") - IJ_indptr = T.match_buffer(indptr_ij, (batch_size + 1,), "int32") - JK_indptr = T.match_buffer(indptr_jk, (batch_size + 1,), "int32") - IK_indptr = T.match_buffer(indptr_ik, (batch_size + 1,), "int32") - # body - # with T.block("root") - for v_vb in T.serial(batch_size): - with T.block("bmm0"): - vb = T.axis.spatial(batch_size, v_vb) - T.reads( - I_indptr[0 : batch_size + 1], - J_indptr[0 : batch_size + 1], - K_indptr[0 : batch_size + 1], - IJ_indptr[0 : batch_size + 1], - JK_indptr[0 : batch_size + 1], - IK_indptr[0 : batch_size + 1], - X_data[0:nnz_ij], - Y_data[0:nnz_jk], - Z_data[0:nnz_ik], - ) - T.writes(Z_data[0:nnz_ik]) - T.block_attr({"sparse": True}) - for v_vi, v_vj, v_vk in T.grid( - I_indptr[vb + 1] - I_indptr[vb], - J_indptr[vb + 1] - J_indptr[vb], - K_indptr[vb + 1] - K_indptr[vb], - ): - with T.block("bmm1"): - vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) - T.reads( - I_indptr[0 : batch_size + 1], - J_indptr[0 : batch_size + 1], - K_indptr[0 : batch_size + 1], - IJ_indptr[0 : batch_size + 1], - JK_indptr[0 : batch_size + 1], - IK_indptr[0 : batch_size + 1], - X_data[0:nnz_ij], - Y_data[0:nnz_jk], - Z_data[0:nnz_ik], - ) - T.writes(Z_data[0:nnz_ik]) - T.block_attr({"sparse": True}) - with T.init(): - Z_data[ - IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk - ] = T.float32(0) - Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] = ( - Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] - + X_data[IJ_indptr[vb] + vi * (J_indptr[vb + 1] - J_indptr[vb]) + vj] - * Y_data[JK_indptr[vb] + vj * (K_indptr[vb + 1] - K_indptr[vb]) + vk] - ) - - -@T.prim_func -def lowered_square_sum( - a: T.handle, - b: T.handle, - indptr_j: T.handle, - indices_j: T.handle, - indptr_k: T.handle, - indices_k: T.handle, - nnz_j: T.int32, - nnz_k: T.int32, - M: T.int32, - N1: T.int32, - N2: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, [nnz_k], dtype="float32") - B_data = T.match_buffer(b, [M], dtype="float32") - J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") - J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") - K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") - K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32") - - for v_vi in T.serial(0, M): - 
with T.block("square_sum_2"): - vi = T.axis.spatial(M, v_vi) - T.reads( - [ - J_indptr[0 : M + 1], - J_indices[0:nnz_j], - K_indptr[0 : nnz_j + 1], - K_indices[0:nnz_k], - A_data[0:nnz_k], - B_data[0:M], - ] - ) - T.writes([B_data[0:M]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): - with T.block("square_sum_1"): - vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads( - [ - J_indptr[0 : M + 1], - J_indices[0:nnz_j], - K_indptr[0 : nnz_j + 1], - K_indices[0:nnz_k], - A_data[0:nnz_k], - B_data[0:M], - ] - ) - T.writes([B_data[0:M]]) - T.block_attr({"sparse": True}) - with T.init(): - B_data[vi] = T.float32(0) - for v_vk in T.serial( - 0, K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj] - ): - with T.block("square_sum"): - vk = T.axis.reduce( - K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj], v_vk - ) - T.reads( - [ - J_indptr[0 : M + 1], - J_indices[0:nnz_j], - K_indptr[0 : nnz_j + 1], - K_indices[0:nnz_k], - A_data[0:nnz_k], - B_data[0:M], - ] - ) - T.writes([B_data[0:M]]) - T.block_attr({"sparse": True}) - B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk] - - -# @T.prim_func -# def lowered_square_sum_two_K( -# a: T.handle, -# b: T.handle, -# indptr_j: T.handle, -# indices_j: T.handle, -# indptr_k0: T.handle, -# indices_k0: T.handle, -# indptr_k1: T.handle, -# indices_k1: T.handle, -# nnz_j: T.int32, -# nnz_k: T.int32, -# M: T.int32, -# N1: T.int32, -# N2: T.int32, -# ) -> None: -# T.func_attr({"global_symbol": "main", "tir.noalias": True}) -# A_data = T.match_buffer(a, [nnz_k], dtype="float32") -# B_data = T.match_buffer(b, [M], dtype="float32") -# J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") -# J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") -# K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32") -# K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32") -# K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32") -# K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32") - -# for v_vi in T.serial(0, M): -# with T.block("square_sum_2"): -# vi = T.axis.spatial(M, v_vi) -# T.reads( -# [ -# J_indptr[0 : M + 1], -# J_indices[0:nnz_j], -# K0_indptr[0 : nnz_j + 1], -# K0_indices[0:nnz_k], -# K1_indptr[0 : nnz_j + 1], -# K1_indices[0:nnz_k], -# A_data[0:nnz_k], -# B_data[0:M], -# ] -# ) -# T.writes([B_data[0:M]]) -# T.block_attr({"sparse": True}) -# for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): -# with T.block("square_sum_1"): -# vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) -# T.reads( -# [ -# J_indptr[0 : M + 1], -# J_indices[0:nnz_j], -# K0_indptr[0 : nnz_j + 1], -# K0_indices[0:nnz_k], -# K1_indptr[0 : nnz_j + 1], -# K1_indices[0:nnz_k], -# A_data[0:nnz_k], -# B_data[0:M], -# ] -# ) -# T.writes([B_data[0:M]]) -# T.block_attr({"sparse": True}) -# with T.init(): -# B_data[vi] = T.float32(0) -# for v_vk in T.serial( -# 0, K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj] -# ): -# with T.block("square_sum"): -# vk = T.axis.reduce( -# K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj], -# v_vk, -# ) -# T.reads( -# [ -# J_indptr[0 : M + 1], -# J_indices[0:nnz_j], -# K0_indptr[0 : nnz_j + 1], -# K0_indices[0:nnz_k], -# K1_indptr[0 : nnz_j + 1], -# K1_indices[0:nnz_k], -# A_data[0:nnz_k], -# B_data[0:M], -# ] -# ) -# T.writes([B_data[0:M]]) -# T.block_attr({"sparse": True}) -# B_data[vi] = ( -# B_data[vi] -# + A_data[ -# T.tvm_lower_bound( -# K0_indices.data, -# 
K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], -# K0_indptr[J_indptr[vi] + vj], -# K0_indptr[J_indptr[vi] + vj + 1], -# dtype="int32", -# ) -# ] -# ) - - -@T.prim_func -def lowered_csr_element_wise( - a: T.handle, - b: T.handle, - indptr: T.handle, - indices: T.handle, - m: T.int32, - n: T.int32, - nnz: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_data = T.match_buffer(a, [nnz], dtype="float32") - B_data = T.match_buffer(b, [nnz], dtype="float32") - J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - for v_vi in T.serial(0, m): - with T.block("csr_element_wise_outer"): - vi = T.axis.spatial(m, v_vi) - T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]]) - T.writes([B_data[0:nnz]]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): - with T.block("csr_element_wise"): - vj = T.axis.spatial(J_indptr[vi + 1] - J_indptr[vi], v_vj) - T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]]) - T.writes([B_data[0:nnz]]) - T.block_attr({"sparse": True}) - B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) - - -@T.prim_func -def lowered_rgcn_forward( - etype: T.handle, - w: T.handle, - x: T.handle, - y: T.handle, - indptr: T.handle, - indices: T.handle, - n: T.int32, - r: T.int32, - feat_size: T.int32, - nnz: T.int32, -) -> None: - E_data = T.match_buffer(etype, [nnz], dtype="int32") - W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32") - X_data = T.match_buffer(x, [n * feat_size], dtype="float32") - Y_data = T.match_buffer(y, [n * feat_size], dtype="float32") - J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for v_vi, v_vout in T.grid(n, feat_size): - with T.block("rgcn-forward_0"): - vi, vout = T.axis.remap("SS", [v_vi, v_vout]) - T.reads( - J_indptr[0 : n + 1], - J_indices[0:nnz], - E_data[0:nnz], - W_data[0 : r * feat_size * feat_size], - X_data[0 : n * feat_size], - Y_data[0 : n * feat_size], - ) - T.writes(Y_data[0 : n * feat_size]) - T.block_attr({"sparse": True}) - for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): - for v_vin in T.serial(feat_size): - with T.block("rgcn-forward_1"): - vj, vin = T.axis.remap("RR", [v_vj, v_vin]) - T.reads( - J_indptr[0 : n + 1], - J_indices[0:nnz], - E_data[0:nnz], - W_data[0 : r * feat_size * feat_size], - X_data[0 : n * feat_size], - Y_data[0 : n * feat_size], - ) - T.writes(Y_data[0 : n * feat_size]) - T.block_attr({"sparse": True}) - with T.init(): - Y_data[vi * feat_size + vout] = T.float32(0) - Y_data[vi * feat_size + vout] = ( - Y_data[vi * feat_size + vout] - + W_data[ - (E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin - ] - * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin] - ) - - -@T.prim_func -def lowered_fused_reduction_4d_2d( - x: T.handle, - y: T.handle, - indptr_j: T.handle, - indptr_k: T.handle, - indptr_l: T.handle, - n: T.int32, - nnz_j: T.int32, - nnz_k: T.int32, - nnz_l: T.int32, -) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - X_data = T.match_buffer(x, [nnz_l], dtype="float32") - Y_data = T.match_buffer(y, [nnz_j], dtype="float32") - J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") - K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") - L_indptr = 
T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") - # body - # with T.block("root") - for v_vi, v_vj in T.grid(1, nnz_j): - with T.block("reduction_4d_2d0"): - vi, vj = T.axis.remap("SS", [v_vi, v_vj]) - T.reads( - J_indptr[0 : n + 1], - K_indptr[0 : nnz_j + 1], - L_indptr[0 : nnz_k + 1], - X_data[0:nnz_l], - Y_data[0:nnz_j], - ) - T.writes(Y_data[0:nnz_j]) - T.block_attr({"sparse": True}) - for v_vk in T.serial(K_indptr[vj + 1] - K_indptr[vj]): - with T.block("reduction_4d_2d1"): - vk = T.axis.reduce(K_indptr[vj + 1] - K_indptr[vj], v_vk) - T.reads( - J_indptr[0 : n + 1], - K_indptr[0 : nnz_j + 1], - L_indptr[0 : nnz_k + 1], - X_data[0:nnz_l], - Y_data[0:nnz_j], - ) - T.writes(Y_data[0:nnz_j]) - T.block_attr({"sparse": True}) - with T.init(): - Y_data[vj] = T.float32(0) - for v_vl in T.serial( - L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk] - ): - with T.block("reduction_4d_2d2"): - vl = T.axis.reduce( - L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk], v_vl - ) - T.reads( - J_indptr[0 : n + 1], - K_indptr[0 : nnz_j + 1], - L_indptr[0 : nnz_k + 1], - X_data[0:nnz_l], - Y_data[0:nnz_j], - ) - T.writes(Y_data[0:nnz_j]) - T.block_attr({"sparse": True}) - Y_data[vj] = Y_data[vj] + X_data[L_indptr[K_indptr[vj] + vk] + vl] - - -@T.prim_func -def lowered_fused_reduction_4d_3d( - x: T.handle, - y: T.handle, - indptr_j: T.handle, - indptr_k: T.handle, - indptr_l: T.handle, - n: T.int32, - nnz_j: T.int32, - nnz_k: T.int32, - nnz_l: T.int32, -) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - X_data = T.match_buffer(x, [nnz_l], dtype="float32") - Y_data = T.match_buffer(y, [nnz_k], dtype="float32") - J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") - K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") - L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") - # body - # with T.block("root") - for v_vi, v_vj, v_vk in T.grid(1, 1, nnz_k): - with T.block("reduction_4d_3d0"): - vi, vj, vk = T.axis.remap("SSS", [v_vi, v_vj, v_vk]) - T.reads( - J_indptr[0 : n + 1], - K_indptr[0 : nnz_j + 1], - L_indptr[0 : nnz_k + 1], - X_data[0:nnz_l], - Y_data[0:nnz_k], - ) - T.writes(Y_data[0:nnz_k]) - T.block_attr({"sparse": True}) - for v_vl in T.serial(L_indptr[vk + 1] - L_indptr[vk]): - with T.block("reduction_4d_3d1"): - vl = T.axis.reduce(L_indptr[vk + 1] - L_indptr[vk], v_vl) - T.reads( - J_indptr[0 : n + 1], - K_indptr[0 : nnz_j + 1], - L_indptr[0 : nnz_k + 1], - X_data[0:nnz_l], - Y_data[0:nnz_k], - ) - T.writes(Y_data[0:nnz_k]) - T.block_attr({"sparse": True}) - with T.init(): - Y_data[vk] = T.float32(0) - Y_data[vk] = Y_data[vk] + X_data[L_indptr[vk] + vl] diff --git a/tests/python/sparsetir/rgcn_two_stage_lowering.py b/tests/python/sparsetir/rgcn_two_stage_lowering.py deleted file mode 100644 index 74ebe2ab6..000000000 --- a/tests/python/sparsetir/rgcn_two_stage_lowering.py +++ /dev/null @@ -1,33 +0,0 @@ -from tvm.script import tir as T -from sparse_tir_scripts import rgcn_hetero_forward -import tvm - - -def test_schedule_rgcn(): - func = rgcn_hetero_forward - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.LowerSparseIter()(mod) - sch = tvm.tir.Schedule(mod) - print(sch.mod["main"].script()) - - blk0 = sch.get_block("rgcn-hetero-forward0") - blk1 = sch.get_block("rgcn-hetero-forward1") - blk2 = sch.get_block("rgcn-hetero-forward2") - read_blk = sch.cache_read(blk1, 2, "shared") - write_blk = sch.cache_write(blk2, 0, "local") - f_out, r = sch.get_loops(blk0) - (i,) = 
sch.get_loops(blk1) - j, f_in = sch.get_loops(blk2) - sch.bind(f_in, "threadIdx.x") - sch.reorder(f_in, j) - # sch.decompose_reduction(blk2, f_in) - i1, i2 = sch.split(i, [None, 8]) - sch.bind(i2, "blockIdx.x") - sch.bind(r, "blockIdx.y") - sch.bind(f_out, "threadIdx.y") - mod = tvm.tir.transform.LowerSparseBuffer()(sch.mod) - print(mod["main"].script()) - - -if __name__ == "__main__": - test_schedule_rgcn() diff --git a/tests/python/sparsetir/test_butterfly.py b/tests/python/sparsetir/test_butterfly.py deleted file mode 100644 index 67c7f8639..000000000 --- a/tests/python/sparsetir/test_butterfly.py +++ /dev/null @@ -1,38 +0,0 @@ -import tvm -import tvm.testing -from tvm.runtime.ndarray import device -import tvm.tir as tir -import scipy.sparse as sp -import numpy as np -from tvm.script import tir as T - - -@T.prim_func -def butterfly(w1: T.handle, w2: T.handle, w3: T.handle, w4: T.handle, x: T.handle, y: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - W1 = T.match_buffer(w1, (16, 2), "float32") - W2 = T.match_buffer(w2, (16, 2), "float32") - W3 = T.match_buffer(w3, (16, 2), "float32") - W4 = T.match_buffer(w4, (16, 2), "float32") - X = T.match_buffer(x, (16, 64), "float32") - Y = T.match_buffer(y, (16, 64), "float32") - - for i, j, k in T.grid(16, 2, 64): - with T.block("wx"): - vi, vj, vk = T.axis.remap("SRS", [i, j, k]) - with T.init(): - Y[vi, vk] = 0. - Y[vi, vk] = Y[vi, vk] +\ - W1[vi, vj] * X[vj * 8 + T.floormod(vi, 8), vk] +\ - W2[vi, vj] * X[T.floordiv(vi, 8) * 8 + vj * 4 + T.floormod(vi, 4), vk] +\ - W3[vi, vj] * X[T.floordiv(vi, 4) * 4 + vj * 2 + T.floormod(vi, 2), vk] +\ - W4[vi, vj] * X[T.floordiv(vi, 2) * 2 + vj, vk] - - -def test_butterfly(): - sch = tir.Schedule(butterfly) - print(sch.mod["main"].script()) - - -if __name__ == "__main__": - test_butterfly() diff --git a/tests/python/sparsetir/test_ellmm_tensorize.py b/tests/python/sparsetir/test_ellmm_tensorize.py deleted file mode 100644 index 5635628d8..000000000 --- a/tests/python/sparsetir/test_ellmm_tensorize.py +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
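For reference, the `butterfly` kernel deleted above expresses a four-factor butterfly factorization of a 16x16 transform: each output row mixes input rows at strides 8, 4, 2, and 1. A minimal NumPy sketch of the same computation, matching the TIR indexing term by term (`butterfly_ref` is an illustrative name, not part of the test suite):

import numpy as np

def butterfly_ref(W1, W2, W3, W4, X):
    # W1..W4 have shape (16, 2); X has shape (16, 64).
    Y = np.zeros_like(X)
    for i in range(16):
        for j in range(2):
            Y[i] += W1[i, j] * X[j * 8 + i % 8]                 # stride-8 stage
            Y[i] += W2[i, j] * X[(i // 8) * 8 + j * 4 + i % 4]  # stride-4 stage
            Y[i] += W3[i, j] * X[(i // 4) * 4 + j * 2 + i % 2]  # stride-2 stage
            Y[i] += W4[i, j] * X[(i // 2) * 2 + j]              # stride-1 stage
    return Y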
- - -import tvm -import tvm.tir as tir -from tvm.script import tir as T -from sparse_tir_scripts import ellmm -from tvm.sparse import lower_sparse_iter, lower_sparse_buffer - - -@T.prim_func -def wmma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), align=128, offset_factor=1, scope="local") - B = T.match_buffer(b, (16, 16), align=128, offset_factor=1, scope="local") - C = T.match_buffer(c, (16, 16), align=128, offset_factor=1, scope="local") - - with T.block("root"): - T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) - T.writes(C[0 : 16, 0 : 16]) - for i, k, j in T.grid(16, 16, 16): - with T.block("update"): - vi, vk, vj = T.axis.remap("SRS", [i, k, j]) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - - -@T.prim_func -def wmma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), align=128, offset_factor=1, scope="local") - B = T.match_buffer(b, (16, 16), align=128, offset_factor=1, scope="local") - C = T.match_buffer(c, (16, 16), align=128, offset_factor=1, scope="local") - - with T.block("root"): - T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) - T.writes(C[0 : 16, 0 : 16]) - T.evaluate( - T.tvm_mma_sync( - C.data, - C.elem_offset // 256, - A.data, - A.elem_offset // 256, - B.data, - B.elem_offset // 256, - C.data, - C.elem_offset // 256, - dtype="handle", - ) - ) - -tir.TensorIntrin.register("wmma_intrin", wmma_desc, wmma_intrin) - - -def test_blocked_ellmm_tensorize(): - NB, MB, FEAT_SIZE, COL, BLK = ellmm.params[-5:] - mod = tvm.IRModule.from_expr( - ellmm.specialize({NB: 32, MB: 32, FEAT_SIZE: 128, COL: 2, BLK: 16}) - ) - mod = lower_sparse_iter(mod) - sch = tvm.tir.Schedule(mod) - blk = sch.get_block("ellmm0") - i, j, bi, bj, f = sch.get_loops(blk) - fo, fi = sch.split(f, [None, 16]) - sch.reorder(i, j, fo, bi, bj, fi) - blk_inner = sch.blockize(bi) - blk, blk_inner = blk_inner, blk - A_local = sch.cache_read(blk_inner, 1, "local") - B_local = sch.cache_read(blk_inner, 2, "local") - C_local = sch.cache_write(blk_inner, 0, "local") - sch.hide_buffer_access(blk_inner, "read", [3]) - sch.tensorize(bi, "wmma_intrin") - print(sch.mod["main"].script()) - - -if __name__ == "__main__": - test_blocked_ellmm_tensorize() diff --git a/tests/python/sparsetir/test_tir_sparse_atomic.py b/tests/python/sparsetir/test_tir_sparse_atomic.py deleted file mode 100644 index 44add658f..000000000 --- a/tests/python/sparsetir/test_tir_sparse_atomic.py +++ /dev/null @@ -1,185 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
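The `test_blocked_ellmm_tensorize` schedule deleted above tensorizes a blocked ELLMM in which every row block holds exactly `col` column blocks of size `blk` x `blk`. A minimal NumPy sketch of the computation being scheduled, assuming the `(nb, col, blk, blk)` layout used by `lowered_ellmm` earlier in this diff (`ellmm_ref` is an illustrative name):

import numpy as np

def ellmm_ref(A, indices, B):
    # A: (nb, col, blk, blk) block values; indices: (nb, col) column-block ids;
    # B: (mb, blk, feat) dense features. Returns C of shape (nb, blk, feat).
    nb, col, blk, _ = A.shape
    feat = B.shape[2]
    C = np.zeros((nb, blk, feat), dtype=B.dtype)
    for i in range(nb):
        for j in range(col):
            # One (blk, blk) @ (blk, feat) product per stored block,
            # reducing over both j and the inner block dimension.
            C[i] += A[i, j] @ B[indices[i, j]]
    return C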
- - -from dgl.heterograph import DGLHeteroGraph -import tvm -import tvm.testing -import tvm.tir as tir -import scipy.sparse as sp -import numpy as np -import dgl -import dgl.function as fn -import torch as th -from tvm.script import tir as T -from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset -from sparse_tir_scripts import rgcn_hetero_forward -from tvm.sparse import lower_sparse_iter, lower_sparse_buffer -from typing import List, Tuple - - -def get_dataset_by_name(name: str): - if name == "aifb": - return AIFBDataset() - elif name == "mutag": - return MUTAGDataset() - elif name == "bgs": - return BGSDataset() - elif name == "am": - return AMDataset() - else: - raise KeyError("Unknown dataset {}.".format(name)) - - -class TorchOpTimer(object): - def __enter__(self): - self.start_event = th.cuda.Event(enable_timing=True) - self.end_event = th.cuda.Event(enable_timing=True) - self.start_event.record() - return self - - def __exit__(self, type, value, traceback): - self.end_event.record() - th.cuda.synchronize() # Wait for the events to be recorded! - self.time = self.start_event.elapsed_time(self.end_event) - - -def prepare_hetero_graph_simplified(g: dgl.DGLHeteroGraph): - ntype_pointer = np.cumsum([0] + [g.number_of_nodes(ntype) for ntype in g.ntypes]) - - etype_pointer = [0] - for etype in g.canonical_etypes: - g_sub = g[etype] - etype_pointer.append(etype_pointer[-1] + g_sub.num_edges()) - - return { - "ntype_node_pointer": th.IntTensor(ntype_pointer), - "etype_edge_pointer": th.IntTensor(etype_pointer), - } - - -blks = ["blockIdx.x", "blockIdx.y", "blockIdx.z"] - -def test_lower_rgcn_hetero( - g: dgl.DGLHeteroGraph, - feat_size: int, - blk_order: List[Tuple[str]], - split_factor_f: int, - split_factor_i: int, -): - N, R, FEAT_SIZE, NNZ_I, NNZ_J = rgcn_hetero_forward.params[-5:] - n = g.num_nodes() - r = len(g.etypes) - nnz_j = g.num_edges() - - feat = th.rand(n, feat_size).to(0) / 100 - out = th.zeros(n, feat_size).to(0) / 100 - weight = th.rand(r, feat_size, feat_size).to(0) - W = tvm.nd.array(weight.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) - X = tvm.nd.array(feat.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) - Y = tvm.nd.array(out.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) - - indptr_i = [th.LongTensor([0])] - indices_i = [] - indptr_j = [th.LongTensor([0])] - indices_j = [] - for etype in g.canonical_etypes: - src_type, _, dst_type = etype - etype_id = g.get_etype_id(etype) - src_type_id = g.get_ntype_id(src_type) - dst_type_id = g.get_ntype_id(dst_type) - g_sub = g[etype] - indptr, indices, _ = g_sub.adj_sparse(fmt="csc") - - unique_nodes = th.nonzero(indptr[:-1] != indptr[1:]).squeeze(1) - indptr_i.append(th.LongTensor([len(unique_nodes)])) - indices_i.append(unique_nodes + g.ntype_pointer[dst_type_id]) - indptr_j.append(indptr[unique_nodes] + g.etype_pointer[etype_id]) - indices_j.append(indices + g.ntype_pointer[src_type_id]) - - indptr_i = tvm.nd.array(th.cat(indptr_i).numpy().astype("int32"), device=tvm.cuda(0)) - indices_i = tvm.nd.array(th.cat(indices_i).numpy().astype("int32"), device=tvm.cuda(0)) - indptr_j = tvm.nd.array(th.cat(indptr_j).numpy().astype("int32"), device=tvm.cuda(0)) - indices_j = tvm.nd.array(th.cat(indices_j).numpy().astype("int32"), device=tvm.cuda(0)) - - nnz_i = indices_i.shape[0] - mod = tvm.IRModule.from_expr( - rgcn_hetero_forward.specialize( - {N: n, R: r, FEAT_SIZE: feat_size, NNZ_I: nnz_i, NNZ_J: nnz_j} - ) - ) - mod = lower_sparse_iter(mod) - sch = tir.Schedule(mod) - - 
blk0 = sch.get_block("rgcn-hetero-forward0") - blk1 = sch.get_block("rgcn-hetero-forward1") - blk2 = sch.get_block("rgcn-hetero-forward2") - read_blk = sch.cache_read(blk1, 2, "shared") - write_blk = sch.cache_write(blk2, 0, "local") - sch.annotate(write_blk, "atomic", True) - f_out, r = sch.get_loops(blk0) - f_out_o, f_out_i = sch.split(f_out, [split_factor_f, None]) - (i,) = sch.get_loops(blk1) - j, f_in = sch.get_loops(blk2) - sch.reorder(f_in, j) - i1, i2 = sch.split(i, [None, split_factor_i]) - sch.lift_loop(i2) - sch.bind(i2, blks[blk_order[0]]) - sch.bind(r, blks[blk_order[1]]) - sch.bind(f_out_o, blks[blk_order[2]]) - sch.bind(f_in, "threadIdx.x") - sch.bind(f_out_i, "threadIdx.y") - _, _, ax2 = sch.get_loops(read_blk) - sch.bind(ax2, "threadIdx.x") - mod = lower_sparse_buffer(sch.mod) - - f = tvm.build(mod["main"], target="cuda") - print(f.imported_modules[0].get_source()) - assert False - - # cold_start = 3 - # total = 10 - # accum = 0 - - # for epoch in range(10): - # with TorchOpTimer() as timer: - # f(W, X, Y, indptr_i, indices_i, indptr_j, indices_j) - # if epoch >= cold_start: - # accum += timer.time - - # print("sparse-tir:\t\t {}ms".format(accum / (total - cold_start))) - - -if __name__ == "__main__": - for feat_size in [32]: # [4, 8, 16, 32, 64]: - for name in ["am"]: # ['aifb', 'mutag', 'bgs', 'am']: - dataset = get_dataset_by_name(name) - g = dataset[0] - type_pointers = prepare_hetero_graph_simplified(g) - g.ntype_pointer = type_pointers["ntype_node_pointer"] - g.etype_pointer = type_pointers["etype_edge_pointer"] - for blk_order in [(2, 0, 1)]: - for split_factor_f in [8]: - for split_factor_i in [512]: - print( - "dataset {}, blk_order {}, split_factor_f {}, split_factor_i {}:".format( - name, blk_order, split_factor_f, split_factor_i - ) - ) - test_lower_rgcn_hetero( - g, feat_size, blk_order, split_factor_f, split_factor_i - ) diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py deleted file mode 100644 index 3a7906451..000000000 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
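Both RGCN tests deleted above compute, for each destination node, a sum of `W[etype] @ X[src]` over its incoming edges. A dense NumPy restatement of the homogeneous variant, following the indexing of `lowered_rgcn_forward` earlier in this diff (`rgcn_ref` is an illustrative name, not part of the test suite):

import numpy as np

def rgcn_ref(etype, W, X, indptr, indices):
    # etype: (nnz,) relation id per edge; W: (r, feat, feat); X: (n, feat);
    # indptr/indices: CSR structure over destination nodes.
    n = indptr.shape[0] - 1
    feat = X.shape[1]
    Y = np.zeros((n, feat), dtype=X.dtype)
    for i in range(n):
        for e in range(indptr[i], indptr[i + 1]):
            # Aggregate the relation-specific transform of each source feature.
            Y[i] += W[etype[e]] @ X[indices[e]]
    return Y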
-import tvm -import tvm.testing -import pytest -from lowered_tir import * -from sparse_tir_scripts import * - - -def test_csrmm(): - mod = tvm.IRModule.from_expr(csrmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) - - -def test_csrmm_dense_iter(): - mod = tvm.IRModule.from_expr(csrmm_dense_iter) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm_dense_iter, True) - - -def test_segment_reduce(): - mod = tvm.IRModule.from_expr(segment_reduce) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_segment_reduce, True) - - -def test_csr_reduce(): - mod = tvm.IRModule.from_expr(csr_reduce) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) - - -def test_bsrmm(): - mod = tvm.IRModule.from_expr(bsrmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) - - -def test_ellpack_mm(): - mod = tvm.IRModule.from_expr(ellmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_ellmm, True) - - -def test_csr_element_wise(): - mod = tvm.IRModule.from_expr(csr_element_wise) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) - - -def test_bmm(): - mod = tvm.IRModule.from_expr(bmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_bmm) - - -def test_sddmm(): - mod = tvm.IRModule.from_expr(sddmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm) - - -def test_fused_sddmm(): - mod = tvm.IRModule.from_expr(fused_sddmm) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm_fuse) - - -def test_square_sum(): - mod = tvm.IRModule.from_expr(square_sum) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum, True) - - -def test_square_sum_two_K(): - mod = tvm.IRModule.from_expr(square_sum_two_K) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True) - - -def test_fused_reduction(): - mod = tvm.IRModule.from_expr(fused_reduction_4d_2d) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_2d, True) - - mod = tvm.IRModule.from_expr(fused_reduction_4d_3d) - mod = tvm.tir.transform.LowerSparseTIR()(mod) - tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_3d, True) - - -if __name__ == "__main__": - test_csrmm() - test_csrmm_dense_iter() - test_segment_reduce() - test_csr_reduce() - test_bsrmm() - test_ellpack_mm() - test_csr_element_wise() - test_sddmm() - test_fused_sddmm() - test_bmm() - test_square_sum() - test_square_sum_two_K() - test_fused_reduction() diff --git a/tests/python/sparsetir/test_tir_sparse_tensorize.py b/tests/python/sparsetir/test_tir_sparse_tensorize.py deleted file mode 100644 index 3c408c617..000000000 --- a/tests/python/sparsetir/test_tir_sparse_tensorize.py +++ /dev/null @@ -1,394 +0,0 @@ -import tvm -from tvm import tir -from tvm.script import tir as T -import tvm.testing -import numpy as np -import scipy.sparse as sp -from tvm.ir import IRModule -from tqdm import tqdm -from tvm.sparse import 
lower_sparse_iter, lower_sparse_buffer - - -@T.prim_func -def bsrmm( - a: T.handle, - b: T.handle, - c: T.handle, - indptr: T.handle, - indices: T.handle, - nb: T.int32, - mb: T.int32, - nnzb: T.int32, - blk: T.int32, - feat_size: T.int32, -) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2}) - I = T.dense_fixed(nb) - J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") - J_detach = T.dense_fixed(mb) - BI = T.dense_fixed(blk) - BJ = T.dense_fixed(blk) - F = T.dense_fixed(feat_size) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float16") - B = T.match_sparse_buffer(b, (J_detach, BJ, F), "float16") - C = T.match_sparse_buffer(c, (I, BI, F), "float32") - - with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [ - i, - bi, - bj, - f, - j, - ]: - with T.init(): - C[i, bi, f] = 0.0 - C[i, bi, f] = C[i, bi, f] + T.float32(A[i, j, bi, bj]) * T.float32(B[j, bj, f]) - - -@T.prim_func -def wmma_sync_desc(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None: - A_frag = T.match_buffer( - a_frag, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a" - ) - B_frag = T.match_buffer( - b_frag, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b" - ) - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator" - ) - - with T.block("root"): - for i, j, k in T.grid(16, 16, 16): - with T.block("update"): - vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - T.block_attr({"sparse": True}) - C_frag[vii, vjj] = C_frag[vii, vjj] + T.cast(A_frag[vii, vkk], "float32") * T.cast( - B_frag[vkk, vjj], "float32" - ) - - -@T.prim_func -def wmma_sync_impl(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None: - A_frag = T.match_buffer( - a_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a" - ) - B_frag = T.match_buffer( - b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b" - ) - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - - with T.block("root"): - T.reads( - [ - C_frag[0:16, 0:16], - A_frag[0:16, 0:16], - B_frag[0:16, 0:16], - ] - ) - T.writes(C_frag[0:16, 0:16]) - for tx in T.thread_binding(0, 32, "threadIdx.x"): - T.evaluate( - T.tvm_mma_sync( - C_frag.data, - C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16), - A_frag.data, - A_frag.elem_offset // 256 + T.floordiv(T.floormod(A_frag.elem_offset, 256), 16), - B_frag.data, - B_frag.elem_offset // 256 + T.floordiv(T.floormod(B_frag.elem_offset, 256), 16), - C_frag.data, - C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16), - dtype="handle", - ) - ) - - -@T.prim_func -def wmma_load_a_desc(a: T.handle, a_frag: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="global") - A_frag = T.match_buffer( - a_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a" - ) - - with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(A_frag[0:16, 0:16]) - for i, j in T.grid(16, 16): - with T.block("load"): - vii, vjj = T.axis.remap("SS", [i, j]) - A_frag[vii, vjj] = A[vii, vjj] - - -@T.prim_func -def wmma_load_a_impl(a: T.handle, a_frag: T.handle) -> None: - s0 = T.var("int32") - s1 = T.var("int32") - A = T.match_buffer( - a, (16, 16), "float16", align=128, offset_factor=16, scope="global", strides=[s0, s1] - ) - A_frag = T.match_buffer( - a_frag, (16, 16), "float16", align=128, 
offset_factor=16, scope="wmma.matrix_a" - ) - - with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(A_frag[0:16, 0:16]) - for tx in T.thread_binding(0, 32, "threadIdx.x"): - T.evaluate( - T.tvm_load_matrix_sync( - A_frag.data, - 16, - 16, - 16, - A_frag.elem_offset // 256 + T.floordiv(T.floormod(A_frag.elem_offset, 256), 16), - A.access_ptr("r"), - A.strides[0], - "row_major", - dtype="handle", - ) - ) - - -@T.prim_func -def wmma_load_b_desc(b: T.handle, b_frag: T.handle) -> None: - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="global") - B_frag = T.match_buffer( - b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b" - ) - with T.block("root"): - for i, j in T.grid(16, 16): - with T.block("load"): - vii, vjj = T.axis.remap("SS", [i, j]) - B_frag[vii, vjj] = B[vii, vjj] - - -@T.prim_func -def wmma_load_b_impl(b: T.handle, b_frag: T.handle) -> None: - s0 = T.var("int32") - s1 = T.var("int32") - B = T.match_buffer( - b, (16, 16), "float16", align=128, offset_factor=16, scope="global", strides=[s0, s1] - ) - B_frag = T.match_buffer( - b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b" - ) - with T.block("root"): - T.reads(B[0:16, 0:16]) - T.writes(B_frag[0:16, 0:16]) - for tx in T.thread_binding(0, 32, "threadIdx.x"): - T.evaluate( - T.tvm_load_matrix_sync( - B_frag.data, - 16, - 16, - 16, - B_frag.elem_offset // 256 + T.floordiv(T.floormod(B_frag.elem_offset, 256), 16), - B.access_ptr("r"), - B.strides[0], - "row_major", - dtype="handle", - ) - ) - - -@T.prim_func -def wmma_fill_desc(c_frag: T.handle) -> None: - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - with T.block("root"): - for i, j in T.grid(16, 16): - with T.block("init"): - vii, vjj = T.axis.remap("SS", [i, j]) - C_frag[vii, vjj] = T.float32(0) - - -@T.prim_func -def wmma_fill_impl(c_frag: T.handle) -> None: - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - with T.block("root"): - T.reads([]) - T.writes(C_frag[0:16, 0:16]) - for tx in T.thread_binding(0, 32, "threadIdx.x"): - T.evaluate( - T.tvm_fill_fragment( - C_frag.data, - 16, - 16, - 16, - C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16), - T.float32(0), - dtype="handle", - ) - ) - - -@T.prim_func -def wmma_store_desc(c_frag: T.handle, c: T.handle) -> None: - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") - with T.block("root"): - for i, j in T.grid(16, 16): - with T.block("store"): - vii, vjj = T.axis.remap("SS", [i, j]) - C[vii, vjj] = C_frag[vii, vjj] - - -@T.prim_func -def wmma_store_impl(c_frag: T.handle, c: T.handle) -> None: - s0 = T.var("int32") - s1 = T.var("int32") - C_frag = T.match_buffer( - c_frag, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - C = T.match_buffer( - c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s0, s1] - ) - with T.block("root"): - T.reads(C_frag[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for tx in T.thread_binding(0, 32, "threadIdx.x"): - T.evaluate( - T.tvm_store_matrix_sync( - C_frag.data, - 16, - 16, - 16, - C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16), - C.access_ptr("w"), - C.strides[0], - "row_major", - dtype="handle", - ) - ) - - 
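Every `impl` above converts a buffer's flat `elem_offset` into a WMMA fragment index with the same expression, `elem_offset // 256 + (elem_offset % 256) // 16`: one full 16x16 fragment spans 256 elements, and a residual multiple of 16 (the `offset_factor`) selects a fragment within the same 256-element tile. A plain-Python restatement (`wmma_fragment_index` is illustrative only):

def wmma_fragment_index(elem_offset: int, frag_rows: int = 16, frag_cols: int = 16) -> int:
    # Fragment index of a frag_rows x frag_cols WMMA fragment at a flat
    # element offset; mirrors the desc/impl expression used above.
    frag_size = frag_rows * frag_cols
    return elem_offset // frag_size + (elem_offset % frag_size) // frag_cols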
-WMMA_SYNC = tir.TensorIntrin.register( - "wmma_sync", - wmma_sync_desc, - wmma_sync_impl, -) - -WMMA_LOAD_A = tir.TensorIntrin.register( - "wmma_load_a", - wmma_load_a_desc, - wmma_load_a_impl, -) - -WMMA_LOAD_B = tir.TensorIntrin.register( - "wmma_load_b", - wmma_load_b_desc, - wmma_load_b_impl, -) - -WMMA_FILL = tir.TensorIntrin.register( - "wmma_fill", - wmma_fill_desc, - wmma_fill_impl, -) - - -WMMA_STORE = tir.TensorIntrin.register( - "wmma_store", - wmma_store_desc, - wmma_store_impl, -) - - -block_size = 16 -nb = 32 -mb = 32 -feat_size = 256 -n = nb * block_size -m = mb * block_size - -A_block = sp.random(mb, nb, dtype="float32", density=0.05, format="csr", random_state=0) -indptr = A_block.indptr -indices = A_block.indices -nnzb = A_block.nnz -np.random.seed(0) -data = np.random.rand(nnzb, block_size, block_size) -A = sp.bsr_matrix((data.astype("float16"), indices, indptr), shape=(n, m)) -x = np.random.rand(m, feat_size).astype("float16") -y_ground_truth = A * x -y = np.zeros((n * feat_size,)).astype("float32") - -v_nb, v_mb, v_nnzb, v_blk, v_feat_size = bsrmm.params[-5:] -bsrmm = bsrmm.specialize( - {v_nb: nb, v_mb: mb, v_nnzb: nnzb, v_blk: block_size, v_feat_size: feat_size} -) -sch = tvm.tir.Schedule(bsrmm) -sp_iteration = sch.get_sparse_iteration("bsrmm") -i, bi, bj, f, j = sch.get_sp_iters(sp_iteration) -sch.sparse_reorder(sp_iteration, [i, j, bi, f, bj]) -mod = lower_sparse_iter(sch.mod) -sch = tir.Schedule(mod) -blk_inner = sch.get_block("bsrmm1") -blk_outer = sch.get_block("bsrmm0") -j, bi, f, bj = sch.get_loops(blk_inner) -fo, fi = sch.split(f, [None, 16]) -sch.reorder(fo, j, bi, fi, bj) -(i,) = sch.get_loops(blk_outer) -sch.bind(i, "blockIdx.x") -sch.bind(fo, "blockIdx.y") -# sch.lift_loop(fo) -new_blk = sch.blockize(bi) -C_local = sch.cache_write(new_blk, 0, "wmma.accumulator") -sch.reverse_compute_at(C_local, fo, True) -sch.decompose_reduction(new_blk, j) -A_local = sch.cache_read(blk_inner, 1, "wmma.matrix_a") -B_local = sch.cache_read(blk_inner, 2, "wmma.matrix_b") -sch.hide_buffer_access(blk_inner, "read", [3]) -sch.tensorize(sch.get_loops(blk_inner)[-3], "wmma_sync") -sch.tensorize(sch.get_loops(B_local)[-2], "wmma_load_b") -sch.tensorize(sch.get_loops(A_local)[-2], "wmma_load_a") -sch.tensorize(sch.get_loops(C_local)[-2], "wmma_store") -sch.tensorize(sch.get_loops(sch.get_block("bsrmm1_init"))[-2], "wmma_fill") -mod = lower_sparse_buffer(sch.mod) -print(mod["main"].script()) - - -# for t in tqdm(range(0, 2)): -# f = tvm.build(mod["main"], target="cuda") -# ctx = tvm.cuda(0) -# A_indptr = tvm.nd.array(np.copy(indptr).astype("int32"), device=ctx) -# A_indices = tvm.nd.array(np.copy(indices).astype("int32"), device=ctx) -# A_data = tvm.nd.array(np.copy(data).astype("float16"), device=ctx) -# X_nd = tvm.nd.array(np.copy(x.reshape(mb, block_size, feat_size)).astype("float16"), device=ctx) -# Y_nd = tvm.nd.array(np.zeros((nb, block_size, feat_size), dtype="float32"), device=ctx) -# f(A_data, X_nd, Y_nd, A_indptr, A_indices) -# tvm.testing.assert_allclose( -# np.copy(y_ground_truth).reshape(nb, block_size, feat_size), -# Y_nd.numpy(), -# rtol=1e-5, -# atol=1e-5, -# ) -# evaluator = f.time_evaluator(f.entry_name, ctx, number=10) -# print("w/o Tensor Cores:") -# print(evaluator(A_data, X_nd, Y_nd, A_indptr, A_indices)) - -for t in tqdm(range(0, 2)): - f = tvm.build(mod["main"], target="cuda") - print(f.imported_modules[0].get_source()) - ctx = tvm.cuda(0) - A_indptr = tvm.nd.array(np.copy(indptr).astype("int32"), device=ctx) - A_indices = 
tvm.nd.array(np.copy(indices).astype("int32"), device=ctx) - A_data = tvm.nd.array(np.copy(data).reshape(-1).astype("float16"), device=ctx) - X_nd = tvm.nd.array(np.copy(x.reshape(-1)).astype("float16"), device=ctx) - Y_nd = tvm.nd.array(np.zeros((nb * block_size * feat_size), dtype="float32"), device=ctx) - print(A_data) - f(A_data, X_nd, Y_nd, A_indptr, A_indices) - tvm.testing.assert_allclose( - np.copy(y_ground_truth).reshape(-1), - Y_nd.numpy(), - rtol=1e-5, - atol=1e-5, - ) -print("with Tensor Cores:") -evaluator = f.time_evaluator(f.entry_name, ctx, number=10) -print(evaluator(A_data, X_nd, Y_nd, A_indptr, A_indices))
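The ground truth used above comes from `y_ground_truth = A * x` with a SciPy BSR matrix. Spelled out against the flattened `(nnzb, blk, blk)` data layout, the correctness check is equivalent to the following sketch (`bsrmm_ref` is an illustrative name; the float32 casts mirror the kernel's `T.float32(...)` accumulation):

import numpy as np

def bsrmm_ref(data, indptr, indices, x, nb, blk, feat_size):
    # data: (nnzb, blk, blk) float16 block values; indptr/indices: BSR structure
    # over row blocks; x: (mb * blk, feat_size). Returns (nb * blk, feat_size),
    # matching Y_nd after reshape.
    y = np.zeros((nb * blk, feat_size), dtype="float32")
    for i in range(nb):
        for e in range(indptr[i], indptr[i + 1]):
            j = indices[e]
            block = data[e].astype("float32")
            y[i * blk:(i + 1) * blk] += block @ x[j * blk:(j + 1) * blk].astype("float32")
    return y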