From 9c26e843a42700e1f236b8937fa8d2068f027004 Mon Sep 17 00:00:00 2001 From: qelk123 <73232000+qelk123@users.noreply.github.com> Date: Sun, 26 Mar 2023 03:43:58 +0800 Subject: [PATCH] fix when binary search find invalid value (#95) * reopen pr * move invalid indices postprocess pass to lower_sparse_iter * refine * upd * lint * add comments and fix bug --- include/tvm/tir/sparse.h | 13 +- include/tvm/tir/transform.h | 2 +- python/tvm/script/tir/special_stmt.py | 9 +- python/tvm/sparse/lower.py | 6 +- python/tvm/tir/sparse.py | 7 +- python/tvm/tir/transform/transform.py | 9 +- src/printer/text_printer.h | 1 + src/printer/tir_text_printer.cc | 21 +++ src/printer/tvmscript_printer.cc | 5 + src/tir/ir/sparse.cc | 18 ++- src/tir/ir/specialize.cc | 2 +- src/tir/transforms/lower_sparse_iter.cc | 192 ++++++++++++++++-------- 12 files changed, 206 insertions(+), 79 deletions(-) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 04460ebd8..b91de0ca4 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -271,17 +271,22 @@ class SparseBufferNode : public BufferNode { PrimExpr GetNNZ() const; Buffer flattened; - + /*! + * \brief The default value in the sparse buffer. + */ + Optional default_value; void VisitAttrs(AttrVisitor* v) { BufferNode::VisitAttrs(v); v->Visit("axes", &axes); v->Visit("extra_storage", &extra_storage); v->Visit("flattened", &flattened); + v->Visit("default_value", &default_value); } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { return BufferNode::SEqualReduce(other, equal) && equal(axes, other->axes) && - equal(extra_storage, other->extra_storage) && equal(flattened, other->flattened); + equal(extra_storage, other->extra_storage) && equal(flattened, other->flattened) && + equal(default_value, other->default_value); } void SHashReduce(SHashReducer hash_reduce) const { @@ -289,6 +294,7 @@ class SparseBufferNode : public BufferNode { hash_reduce(axes); hash_reduce(extra_storage); hash_reduce(flattened); + hash_reduce(default_value); } static constexpr const char* _type_key = "tir.sparse.SparseBuffer"; @@ -304,7 +310,8 @@ class SparseBufferNode : public BufferNode { class SparseBuffer : public Buffer { public: TVM_DLL explicit SparseBuffer(Var data, Array axes, DataType dtype, String name, - Optional extra_storage, Span span = Span()); + Optional extra_storage, + Optional default_value = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, Buffer, SparseBufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferNode); }; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index c2cc01af5..551cac8cc 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -629,7 +629,7 @@ TVM_DLL Pass RenormalizeSplitPattern(); * \brief Lower sparse iterations in Sparse TIR. * \return The pass. */ -TVM_DLL Pass LowerSparseIter(); +TVM_DLL Pass LowerSparseIter(bool check_invalid_binary_search = false); /*! * \brief Lower sparse buffers in Sparse TIR. diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 224e41298..6a6f72dc4 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -983,7 +983,6 @@ def preflattened_buffer( buffer_type="default", span=None, ): - param = None for key, value in self.context.func_buffer_map.items(): if value.same_as(postflattened): @@ -1194,6 +1193,7 @@ def match_sparse_buffer( axes: List[Axis], dtype: str = "float32", extra_storage: Optional[PrimExpr] = None, + default_value: Optional[PrimExpr] = None, span: Optional[Span] = None, ): if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: @@ -1214,7 +1214,7 @@ def match_sparse_buffer( storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type data = Var(buffer_name, PointerType(storage_type, "global"), span) buffer = tvm.tir.sparse.SparseBuffer( - data, axes, dtype, buffer_name, extra_storage, span + data, axes, dtype, buffer_name, extra_storage, default_value, span ) self.context.func_buffer_map[param] = buffer self.context.update_symbol(buffer_name, buffer, self.node) @@ -1235,6 +1235,7 @@ def alloc_sparse_buffer( axes: List[Axis], dtype: str = "float32", scope: str = "global", + default_value: Optional[PrimExpr] = None, span: Optional[Span] = None, ): if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: @@ -1245,7 +1246,9 @@ def alloc_sparse_buffer( buffer_name: str = self.node.lhs[0].id.name data = Var(buffer_name, PointerType(PrimType(dtype), scope), span) - buffer = tvm.tir.sparse.SparseBuffer(data, axes, dtype, buffer_name, 0, span) + buffer = tvm.tir.sparse.SparseBuffer( + data, axes, dtype, buffer_name, 0, default_value, span + ) if self.context.current_block_scope(): self.context.current_block_scope().alloc_buffers.append(buffer) else: diff --git a/python/tvm/sparse/lower.py b/python/tvm/sparse/lower.py index d18aae307..51742ef73 100644 --- a/python/tvm/sparse/lower.py +++ b/python/tvm/sparse/lower.py @@ -20,17 +20,19 @@ from tvm.tir.transform import LowerSparseBuffer, LowerSparseIter -def lower_sparse_iter(mod: IRModule): +def lower_sparse_iter(mod: IRModule, check_invalid_binary_search: bool = False): """Lower sparse iterators in Sparse TIR. Parameters ---------- mod : IRModule The IRModule to lower. + check_invalid_binary_search : bool + Whether check invalid indices made by binary search. """ if not isinstance(mod, IRModule): raise TypeError("Expected IRModule, but got {}".format(type(mod))) - return LowerSparseIter()(mod) + return LowerSparseIter(check_invalid_binary_search)(mod) def lower_sparse_buffer(mod: IRModule): diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index b64f0c370..44f8f0a61 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -298,6 +298,8 @@ class SparseBuffer(Buffer): The name of the sparse buffer extra_storage : Optional[PrimExpr] Required extra storage (e.g. for indptr) + default_value : Optional[PrimExpr] + The default value about missing value of the the sparse buffer span : Span """ @@ -306,10 +308,11 @@ class SparseBuffer(Buffer): dtype: str name: str extra_storage: Optional[PrimExpr] + default_value: Optional[PrimExpr] span: Span - def __init__(self, data, axes, dtype, name, extra_storage, span): - self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, data, axes, dtype, name, extra_storage, span) # type: ignore + def __init__(self, data, axes, dtype, name, extra_storage, default_value, span): + self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, data, axes, dtype, name, extra_storage, default_value, span) # type: ignore @tvm._ffi.register_object("tir.sparse.SpIterVar") diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 5113feaa2..785712dd2 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -806,15 +806,20 @@ def RenormalizeSplitPattern(): return _ffi_api.RenormalizeSplitPattern() # type: ignore -def LowerSparseIter(): +def LowerSparseIter(check_invalid_binary_search: bool = False): """Lower iterations in Sparse TIR + Parameters + ---------- + check_invalid_binary_search : bool + Whether check invalid indices made by binary search. + Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerSparseIter() # type: ignore + return _ffi_api.LowerSparseIter(check_invalid_binary_search) # type: ignore def LowerSparseBuffer(): diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c34c4a5b6..525daeefb 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -367,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const WhileNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; + Doc VisitStmt_(const SparseIterationNode* op) override; Doc VisitStmtDefault_(const Object* op) override; Doc VisitType_(const PrimTypeNode* node) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fe829016b..10caaa3be 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -606,6 +606,27 @@ Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const SparseIterationNode* op) { + Doc doc; + doc << "sparse_iteration " << op->name << "("; + doc << Print(op->sp_iter_vars[0]->var); + for (int i = 1; i < static_cast(op->sp_iter_vars.size()); ++i) { + doc << "," << Print(op->sp_iter_vars[i]->var); + } + doc << ")"; + Doc body; + if (op->init.defined()) { + Doc init_block; + init_block << "with init()"; + init_block << PrintBody(op->init.value()); + body << init_block << Doc::NewLine(); + } + // Print body + body << Print(op->body); + doc << " {" << Doc::Indent(2, Doc::NewLine() << body) << Doc::NewLine() << "}"; + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { const auto* block_op = op->block.as(); // print block name and block vars diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index eee237f09..706495790 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -485,6 +485,11 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { if (sp_buf->extra_storage.defined()) { doc << ", extra_storage=" << Print(sp_buf->extra_storage.value()); } + + // default value + if (sp_buf->default_value.defined()) { + doc << ", default_value=" << Print(sp_buf->default_value.value()); + } // scope const auto* ptr_type = sp_buf->data->type_annotation.as(); ICHECK(ptr_type) << "Buffer variable is not of pointer type"; diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 42e49615b..9c892011d 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -241,7 +241,8 @@ PrimExpr SparseBufferNode::GetNNZ() const { return flattened->shape[0]; } /*! \brief Default constructor of SparseBuffer */ SparseBuffer::SparseBuffer(Var data, Array axes, DataType dtype, String name, - Optional extra_storage, Span span) { + Optional extra_storage, Optional default_value, + Span span) { ObjectPtr node = make_object(); CHECK_GT(static_cast(axes.size()), 0) << "ValueError: A SparseBuffer should have at least one dimension"; @@ -275,6 +276,13 @@ SparseBuffer::SparseBuffer(Var data, Array axes, DataType dtype, String na node->extra_storage = extra_storage; node->name = name; node->dtype = dtype; + if (!default_value) { + node->default_value = Cast(dtype, Integer(0)); + } else { + ICHECK(default_value.value()->dtype == dtype) + << "sparse buffer default value should match buffer data type"; + node->default_value = default_value; + } // collect shape Array shape; for (const Axis& axis : axes) { @@ -307,9 +315,10 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") .set_body_typed([](Var data, Array axes, DataType dtype, String name, - Optional extra_storage, Span span) { + Optional extra_storage, Optional default_value, + Span span) { return SparseBuffer(std::move(data), std::move(axes), std::move(dtype), std::move(name), - std::move(extra_storage), std::move(span)); + std::move(extra_storage), std::move(default_value), std::move(span)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -327,6 +336,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (op->extra_storage.defined()) { p->stream << ", " << op->extra_storage.value(); } + if (op->default_value.defined()) { + p->stream << ", " << op->default_value.value(); + } p->stream << ")"; }); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 3a8a1f494..e7445f329 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -252,7 +252,7 @@ class PrimFuncSpecializer : public StmtExprMutator { return buffer; } else { return SparseBuffer(sp_buf->data, std::move(axes), sp_buf->dtype, sp_buf->name, - sp_buf->extra_storage, sp_buf->span); + sp_buf->extra_storage, sp_buf->default_value, sp_buf->span); } } else { Array shape = diff --git a/src/tir/transforms/lower_sparse_iter.cc b/src/tir/transforms/lower_sparse_iter.cc index 8fb56fab5..018e1d184 100644 --- a/src/tir/transforms/lower_sparse_iter.cc +++ b/src/tir/transforms/lower_sparse_iter.cc @@ -324,31 +324,18 @@ class LowerSparseIterContext { std::vector stack_; }; -/*! - * \brief The binary search type. - */ -typedef enum { - /*! \brief `lower_bound` search. */ - kLower = 0U, - /*! - * \brief finds the element equals given value, and creates a `success` buffer to - * store successful state. - */ - kEqual = 1U, - /*! \brief `upper_bound` search. */ - kUpper = 2U, -} BinarySearchType; - /*! * \brief Lower sparse iterations by rewriting AST. */ class IterTransformer : public StmtExprMutator { public: explicit IterTransformer(Map axis_indptr_map, - Map axis_indices_map, const Array& sp_axes) + Map axis_indices_map, const Array& sp_axes, + bool check_invalid_binary_search) : axis_indptr_map_(std::move(axis_indptr_map)), axis_indices_map_(std::move(axis_indices_map)), - bsearch_blk_counter(0) { + bsearch_blk_counter(0), + check_invalid_binary_search_(check_invalid_binary_search) { CreateBaseDomMap(sp_axes); } @@ -503,6 +490,14 @@ class IterTransformer : public StmtExprMutator { return false; } }; + bool back_up_check_region = binary_search_vaild_check_region; + if (check_invalid_binary_search_) { + auto valid_check_flag = + sp_iteration->annotations.Get("binary_search_vaild_check").value_or(Bool(true)); + if (Downcast(valid_check_flag) == false) { + binary_search_vaild_check_region = false; + } + } int n_iters = static_cast(sp_iteration->sp_iter_vars.size()); Array loop_vars; @@ -633,6 +628,9 @@ class IterTransformer : public StmtExprMutator { // Create new block. Map annotations = sp_iteration->annotations; annotations.Set("sparse", Bool(true)); + if (binary_search_vaild_check_region && check_invalid_binary_search_) { + annotations.Set("binary_search_vaild_check", Bool(true)); + } Block block(/*iter_vars=*/info.block_iters, /*reads=*/reads_new, /*writes=*/writes_new, @@ -673,6 +671,9 @@ class IterTransformer : public StmtExprMutator { Map annotations; annotations.Set("sparse", Bool(true)); annotations.Set("preprocess", Bool(true)); + if (check_invalid_binary_search_) { + annotations.Set("is_binary_search_block", Bool(true)); + } Array reads, writes; if (i == static_cast(bsearch_block_info.size()) - 1) { // innermost @@ -713,6 +714,8 @@ class IterTransformer : public StmtExprMutator { } } + // restore binary search vaild check region + binary_search_vaild_check_region = back_up_check_region; // Exit the context. ctx_.ExitScope(); @@ -776,8 +779,7 @@ class IterTransformer : public StmtExprMutator { } offset = BinarySearch(indptr_buf, prefix_indices, Integer(0), - GetParentAxis(original_axis)->nnz + Integer(1), offset, kUpper, true) - .first; + GetParentAxis(original_axis)->nnz + Integer(1), offset, false, true); } Axis original_axis = GetAxisBeforeFuse(fused_axis->group[fused_axis->index]); if (!original_axis->IsSparse()) { @@ -846,33 +848,21 @@ class IterTransformer : public StmtExprMutator { * \param lb The lower bound (close) of the search range [lb, ub) * \param ub The upper bound (open) of the search range [lb, ub) * \param val The value to be searched. - * \param search_type The binary search type: kLower, kEqual or kUpper, - * If kLower, returns the leftmost index of the locations whose value is greater or - * equal to `val`. - * If kEqual, returns the leftmost index of the locations whose value is greater or - * equal to `val`, and creates a `success` buffer stores success state of the search ( - * whether there exists elements equals the given value.) - * If kUpper, returns the leftmost index of the locations whose value is greater - * than `val`. - * If no such index exists, returns `ub`. * \param minus_one Whether to minus one to the final result (when used together with - * `kUpper`, you will get the rightmost index of the suitable location to maintain + * `left=false`, you will get the rightmost index of the suitable location to maintain * ascending order). */ - std::pair> BinarySearch(SparseBuffer buf, - Array prefix_indices, PrimExpr lb, - PrimExpr ub, PrimExpr val, - BinarySearchType search_type, - bool minus_one = false) { + PrimExpr BinarySearch(SparseBuffer buf, Array prefix_indices, PrimExpr lb, PrimExpr ub, + PrimExpr val, bool left, bool minus_one = false) { /* Algorithm: - * - lower_bound (search_type == kLower || search_type == kEqual) + * - when left = true * - pre-condition * lb < ub, and the last dimension of buf is sorted. * - loop-invariant * low <= mid < high, buf[..., lb:low] < val, buf[..., high:ub] >= val * - post-condition * low = mid = high, buf[..., lb:low] < val, buf[..., high:ub] >= val - * - upper_bound (search_type == kUpper) + * - when left = false * - pre-condition * lb < ub, and the last dimension of buf is sorted. * - loop-invariant @@ -891,7 +881,7 @@ class IterTransformer : public StmtExprMutator { args.push_back(lb); args.push_back(ub); args.push_back(val); - args.push_back(Integer(static_cast(search_type))); + args.push_back(Bool(left)); args.push_back(Bool(minus_one)); if (bsearch_map_.count(args)) { return bsearch_map_[args]; @@ -945,13 +935,6 @@ class IterTransformer : public StmtExprMutator { String mid_buf_name = "mid_" + std::to_string(bsearch_blk_counter); SparseBuffer mid = SparseBuffer(Var(mid_buf_name, PointerType(PrimType(dtype), "global")), axes, dtype, mid_buf_name, Integer(0)); - String success_buf_name = "success_" + std::to_string(bsearch_blk_counter); - Optional success; - if (search_type == kEqual) { - success = SparseBuffer( - Var(success_buf_name, PointerType(PrimType(DataType{kDLInt, 8, 1}), "global")), axes, - DataType{kDLInt, 8, 1}, success_buf_name, Integer(0)); - } Stmt low_store = BufferStore(low, lb, {Integer(0)}); Stmt high_store = BufferStore(high, ub, {Integer(0)}); @@ -964,7 +947,6 @@ class IterTransformer : public StmtExprMutator { Array indices = prefix_indices; indices.push_back(mid_val); PrimExpr pivot = BufferLoad(buf, indices); - bool left = search_type == kLower || search_type == kEqual; PrimExpr pivot_cmp_cond = left ? (pivot < val) : (pivot > val); Stmt if_true = left ? BufferStore(low, mid_val + 1, {Integer(0)}) : BufferStore(high, mid_val, {Integer(0)}); @@ -978,10 +960,10 @@ class IterTransformer : public StmtExprMutator { body_stmts.push_back( BufferStore(mid, BufferLoad(mid, mid_indices) - Integer(1), mid_indices)); } - if (search_type == kEqual) { - body_stmts.push_back( - BufferStore(success.value(), - Select(mid_val != ub && pivot == val, Integer(1), Integer(0)), mid_indices)); + if (!binary_search_vaild_check_region && check_invalid_binary_search_) { + Stmt then_stmt = BufferStore(mid, -1, mid_indices); + PrimExpr if_stmt = (pivot != val || mid_val == ub); + body_stmts.push_back(IfThenElse(if_stmt, then_stmt)); } SeqStmt body(body_stmts); @@ -1002,11 +984,8 @@ class IterTransformer : public StmtExprMutator { BufferRegion write = BufferRegion(mid, write_regions); bsearch_structures.push_back( BinarySearchStructure({name, body, var_map, inv_var_map, {low, high}, read, write})); - std::pair> ret_val = { - mid_val, search_type == kEqual ? BufferLoad(success.value(), mid_indices) - : Optional(NullOpt)}; - bsearch_map_[args] = ret_val; - return ret_val; + bsearch_map_[args] = mid_val; + return mid_val; } /*! \brief Return indices viewed in a given buffer. */ @@ -1084,8 +1063,7 @@ class IterTransformer : public StmtExprMutator { extent = buf_axis->nnz_cols.value(); } new_index = - BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, kLower) - .first; + BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, true); } else { // it's dense axis. new_index = coordinate; @@ -1140,13 +1118,98 @@ class IterTransformer : public StmtExprMutator { Map axis_indptr_map_; // axis to indptr buffer map. Map axis_indices_map_; // axis to indices buffer map. std::unordered_map base_dom_map_; // The base dom map. - std::unordered_map>, StructuralHash, - StructuralEqual> + std::unordered_map bsearch_map_; // The map storing existing binary search keys and values. int bsearch_blk_counter; // Counter for generated binary search blocks. + bool binary_search_vaild_check_region = true; + bool check_invalid_binary_search_ = false; }; -PrimFunc LowerSparseIter(PrimFunc f) { +class InvalidIndicesPostProcess : public StmtExprMutator { + public: + InvalidIndicesPostProcess() {} + + private: + /*! \brief Visitor of block node. + * \note For block with attr "binary_search_vaild_check",just skip it. + For binary search block,collect mid buffer it writes. + For other block do indices post process. + */ + Stmt VisitStmt_(const BlockNode* op) final { + auto it = op->annotations.find("binary_search_vaild_check"); + if (it != op->annotations.end() && (Downcast((*it).second)->value)) { + return GetRef(op); + } + it = op->annotations.find("is_binary_search_block"); + if (it != op->annotations.end() && Downcast((*it).second)->value) { + for (auto i : op->writes) { + binary_search_buffer.insert(i->buffer.get()); + } + return GetRef(op); + } + auto ret = StmtExprMutator::VisitStmt_(op); + return ret; + } + + /*! \brief Visitor of buffer store node. + * \note Get default value expr for bufferload with invalid indices + */ + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (binary_search_buffer.count(op->buffer.get()) && + buffer_processed.count(op->buffer.get()) == 0) { + buffer_need_process.push_back(GetRef(op)); + find_mid_buffer++; + } else { + auto find_backup = find_mid_buffer; + for (auto i : op->indices) { + VisitExpr(i); + } + if (find_mid_buffer != find_backup) { + find_mid_buffer = find_backup; + auto sparse_buffer = Downcast(op->buffer); + return sparse_buffer->default_value.value(); + } + } + return StmtExprMutator::VisitExpr_(op); + } + + /*! \brief Visitor of buffer store node. + * \note Construct IfThenElse stmt for invalid indices bufferload + */ + Stmt VisitStmt_(const BufferStoreNode* op) final { + size_t original_size = buffer_need_process.size(); + PrimExpr value = VisitExpr(op->value); + size_t new_buffer_num = buffer_need_process.size() - original_size; + if (new_buffer_num) { + if (new_buffer_num > 1) { + auto bufferload = buffer_need_process[original_size]; + for (size_t i = original_size + 1; i < buffer_need_process.size(); i++) { + ICHECK(buffer_need_process[i].same_as(bufferload)) + << "current only allow same mid buffer load expr in one buffer store stmt"; + } + } + auto buffer_found = buffer_need_process.back(); + buffer_need_process.erase(buffer_need_process.end() - new_buffer_num, + buffer_need_process.end()); + buffer_processed.insert(buffer_found->buffer.get()); + PrimExpr if_stmt = (-1 != buffer_found); + auto new_stmt = + IfThenElse(if_stmt, StmtExprMutator::VisitStmt_(op), + BufferStore(op->buffer, ana.Simplify(value), op->indices, op->span)); + buffer_processed.erase(buffer_found->buffer.get()); + return new_stmt; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + std::vector buffer_need_process; + std::set buffer_processed; + std::set binary_search_buffer; + arith::Analyzer ana; + int find_mid_buffer = 0; +}; + +PrimFunc LowerSparseIter(PrimFunc f, bool check_invalid_binary_search) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f) && SparseTIRLevel(f) == 2) { PrimFuncNode* fptr = f.CopyOnWrite(); @@ -1156,7 +1219,8 @@ PrimFunc LowerSparseIter(PrimFunc f) { std::tie(axis_indptr_map, axis_indices_map, fptr->buffer_map, fptr->sp_axes, buf_doms) = UpdateMetadata(f); // Step 2. Lower iterations. - IterTransformer lower_sparse(axis_indptr_map, axis_indices_map, fptr->sp_axes); + IterTransformer lower_sparse(axis_indptr_map, axis_indices_map, fptr->sp_axes, + check_invalid_binary_search); Stmt body = lower_sparse(std::move(fptr->body)); // Step 3. Wrap with root block, insert bsearch blocks and allocated buffers. if (!lower_sparse.bsearch_structures.empty()) { @@ -1175,6 +1239,10 @@ PrimFunc LowerSparseIter(PrimFunc f) { Map new_attr_dict = fptr->attrs->dict; new_attr_dict.Set("sparse_tir_level", Integer(1)); fptr->attrs = DictAttrs(new_attr_dict); + // Step 5. postprocess bufferload with possible invalid indices + if (check_invalid_binary_search) { + fptr->body = InvalidIndicesPostProcess()(std::move(fptr->body)); + } return f; } else { return f; @@ -1186,9 +1254,9 @@ namespace transform { /*! * \brief The lowering pass from TIR to Sparse TIR. */ -Pass LowerSparseIter() { +Pass LowerSparseIter(bool check_invalid_binary_search) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LowerSparseIter(std::move(f)); + return LowerSparseIter(std::move(f), check_invalid_binary_search); }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseIter", {}); }