From 0fc047c98b1ebf730b8c9aad8b94ddac28a7b34b Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 5 Jul 2024 11:45:12 +0800 Subject: [PATCH] [Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128) * prefer T.where for reverse compute-inlined block with predicate * update ut scripts --------- Co-authored-by: wrongtest --- src/tir/schedule/primitive/compute_inline.cc | 44 ++++++++------ tests/python/dlight/test_gpu_matmul.py | 20 +++---- .../dlight/test_gpu_matmul_tensorize.py | 20 +++---- ...test_meta_schedule_schedule_rule_mlt_tc.py | 4 +- .../test_tir_schedule_compute_inline.py | 59 ++++++++++++++++--- 5 files changed, 98 insertions(+), 49 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index d6be0e5805dd..df74497b4a69 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner { using BaseInliner::VisitStmt_; /*! \brief Generate the predicate after inlining based on the consumer predicate */ - Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) { + BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) { // Bind the producer block iter domains for simplification Map<Var, PrimExpr> subst_map; + Block producer_block = producer_block_realize->block; for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) { const IterVar& iter = producer_block->iter_vars[i]; + const PrimExpr& binding = producer_block_realize->iter_values[i]; + subst_map.Set(iter->var, binding); analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); } if (producer_block->annotations.count(tir::attr::auto_copy) != 0) { @@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_); // Simplify the predicate using the producer block iter domains predicate = analyzer_.Simplify(predicate); - ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block); if (is_one(predicate)) { - return Block(block); - } - if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) { - PrimExpr if_predicate = analyzer_.Simplify(if_->condition); - if (!StructuralEqual()(predicate, if_predicate)) { - predicate = analyzer_.Simplify(predicate && if_->condition); + return producer_block_realize; + } + if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) { + if (!if_->else_case.defined()) { + PrimExpr if_predicate = analyzer_.Simplify(if_->condition); + if (!StructuralEqual()(predicate, if_predicate)) { + predicate = analyzer_.Simplify(predicate && if_->condition); + producer_block.CopyOnWrite()->body = if_->then_case; + } } - block->body = IfThenElse(predicate, if_->then_case); - return Block(block); } - block->body = IfThenElse(predicate, block->body); - return Block(block); + PrimExpr outer_predicate = Substitute(predicate, subst_map); + auto n = producer_block_realize.CopyOnWrite(); + n->block = producer_block; + n->predicate = analyzer_.Simplify(outer_predicate); + return GetRef<BlockRealize>(n); } - Stmt VisitStmt_(const BlockNode* op) final { - Block src_block = GetRef<Block>(op); - Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op)); - if (op == producer_block_) { - tgt_block = BuildInlinedConsumerPredicate(tgt_block.get()); - block_reuse.Set(src_block, tgt_block); + Stmt VisitStmt_(const BlockRealizeNode* op) final { + Block src_block = op->block; + BlockRealize tgt_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op)); + if 
(src_block.get() == producer_block_) { + tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize); + block_reuse.Set(src_block, tgt_block_realize->block); } - return std::move(tgt_block); + return std::move(tgt_block_realize); } Stmt VisitStmt_(const BufferStoreNode* _store) final { diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 63117073d156..ca32c286abfe 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -113,10 +113,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) + T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < m) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[T.int64(0), v1, v2]) - if v1 < m: - matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] # fmt: on @@ -200,10 +200,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) + T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[0, v1, v2]) - if v1 < m: - matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] # fmt: on mod = tvm.IRModule({"main": func}) @@ -466,10 +466,10 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) + T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n) T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2]) T.writes(p_output0_intermediate[T.int64(0), v1, v2]) - if v1 < n: - p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2] + p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2] # fmt: on @@ -596,9 +596,9 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) + T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n) T.writes(var_T_multiply_intermediate[v1, v2]) - if v1 < n: - var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])) + var_T_multiply_intermediate[v1, v2] = 
T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])) # fmt: on @@ -666,10 +666,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), v0 = T.axis.spatial(T.int64(1), ax0) v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1) + T.where(ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1 < m) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[T.int64(0), v1, v2]) - if v1 < m: - matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] # fmt: on diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py index 59ccfec55cc5..94d6a8e42ad3 100644 --- a/tests/python/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/dlight/test_gpu_matmul_tensorize.py @@ -254,10 +254,10 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) + T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15) T.reads(compute_reindex_pad_local[v0, v1, v2]) T.writes(compute[v1, v2]) - if v1 < m and v2 < 15: - compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2] + compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2] # fmt: on @@ -417,11 +417,11 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) + T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n) T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) T.writes(p_output0_intermediate[0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - if v1 < n: - p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] + p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] # fmt: on @@ -690,11 +690,11 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. 
v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) + T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m) T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2]) T.writes(matmul_1[0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - if v1 < m: - matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] + matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] # fmt: on @@ -831,10 +831,10 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) + T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size) T.reads(C_reindex_pad_shared[v0, v1, v2]) T.writes(C[v1, 0, v2]) - if v1 < batch_size: - C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] + C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] # fmt: on @@ -971,10 +971,10 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f v0 = T.axis.spatial(1, ax0_1) v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) + T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size) T.reads(C_reindex_pad_shared[v0, v1, v2]) T.writes(C[v1, 0, v2]) - if v1 < batch_size: - C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] + C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] if __name__ == "__main__": diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index da00f294ba0e..df8607e55127 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -856,11 +856,11 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 v3 = T.axis.spatial(1, 0) v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 4}) - if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 127: - compute[v4 + v2 * 16 + v0 * 32, 
v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py index 5cf59985d353..2f779612a72a 100644 --- a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py +++ b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py @@ -624,8 +624,8 @@ def elementwise_overcomputed_producer_reverse_inlined( for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - if vi < 127 and vj < 127: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + T.where(i < 127 and j < 127) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -652,8 +652,8 @@ def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined( with T.block("B"): vi = T.axis.spatial(128, i // 128) vj = T.axis.spatial(128, i % 128) - if vi < 127 and vj < 127: - C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + T.where(i < 16255 and i % 128 < 127) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 @T.prim_func @@ -678,8 +678,8 @@ def elementwise_overcomputed_producer_injective_load_reverse_inlined( for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): with T.block("B"): vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) - if vi * 16 + vm < 127 and vj * 16 + vn < 127: - C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 + T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127) + C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 @T.prim_func @@ -740,8 +740,7 @@ def elementwise_predicate_producer_inlined(a: T.handle, c: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) T.writes(C[vi, vj]) - if vi < 127: - C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) + C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) # fmt: off @@ -1486,5 +1485,49 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) +def test_reverse_compute_inline_slicing_then_cachewrite(): + @T.prim_func + def before( + x: T.Buffer((1, 16, 7, 7), "float32"), + T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"), + ): + T_add = T.alloc_buffer((1, 16, 7, 7)) + for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1) + for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = T_add[ + v_ax0, v_ax1, v_ax2, v_ax3 + ] + + @T.prim_func + def after( + x: T.Buffer((1, 16, 7, 7), "float32"), + T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"), + ): + T_strided_slice_with_axes_global = T.alloc_buffer((1, 12, 7, 7)) + for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.where(ax1 < 12) + T_strided_slice_with_axes_global[v_ax0, v_ax1, v_ax2, v_ax3] = x[ + v_ax0, v_ax1, v_ax2, v_ax3 + ] + T.float32(1) + for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): + with T.block("T_strided_slice_with_axes_global"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T_strided_slice_with_axes[v0, v1, v2, 
v3] = T_strided_slice_with_axes_global[ + v0, v1, v2, v3 + ] + + sch = tir.Schedule(before) + sch.reverse_compute_inline(sch.get_block("T_strided_slice_with_axes")) + sch.cache_write(sch.get_block("T_add"), 0, "global") + assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) + + if __name__ == "__main__": tvm.testing.main()
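
A minimal usage sketch of the behavior this patch targets (the prim_func and buffer names below are illustrative, mirroring the elementwise_overcomputed_producer tests above): a 128x128 producer block "B" over-computes for a 127x127 consumer "C". After reverse_compute_inline, the consumer's bound check is expected to surface as a T.where predicate on block "B" rather than as an if-statement wrapped around its store.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle, c: T.handle) -> None:
    # Producer over-computes the padded 128x128 region.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (127, 127))
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    # Consumer only reads the valid 127x127 region.
    for i, j in T.grid(127, 127):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


sch = tir.Schedule(before)
sch.reverse_compute_inline(sch.get_block("C"))
# With this patch, block "B" should carry `T.where(i < 127 and j < 127)`
# instead of guarding its body with `if vi < 127 and vj < 127`.
sch.mod.show()

Before this change, the same schedule emitted the guard as an if-statement inside the block body, as the removed lines in test_tir_schedule_compute_inline.py above show.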