From 5b11e0f308b833f16489cc8e1dc1ff793ba8e4f1 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 23 Mar 2023 19:19:01 -0700 Subject: [PATCH] Add an `exact_match` argument to `BinarySearch` function in Sparse Iteration Lowering pass (#96) * udp * upd * lint --- src/tir/transforms/lower_sparse_iter.cc | 77 +++++++++++++++++++++---- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/lower_sparse_iter.cc b/src/tir/transforms/lower_sparse_iter.cc index d81715dfb..8fb56fab5 100644 --- a/src/tir/transforms/lower_sparse_iter.cc +++ b/src/tir/transforms/lower_sparse_iter.cc @@ -324,6 +324,21 @@ class LowerSparseIterContext { std::vector stack_; }; +/*! + * \brief The binary search type. + */ +typedef enum { + /*! \brief `lower_bound` search. */ + kLower = 0U, + /*! + * \brief finds the element equals given value, and creates a `success` buffer to + * store successful state. + */ + kEqual = 1U, + /*! \brief `upper_bound` search. */ + kUpper = 2U, +} BinarySearchType; + /*! * \brief Lower sparse iterations by rewriting AST. */ @@ -761,7 +776,8 @@ class IterTransformer : public StmtExprMutator { } offset = BinarySearch(indptr_buf, prefix_indices, Integer(0), - GetParentAxis(original_axis)->nnz + Integer(1), offset, false, true); + GetParentAxis(original_axis)->nnz + Integer(1), offset, kUpper, true) + .first; } Axis original_axis = GetAxisBeforeFuse(fused_axis->group[fused_axis->index]); if (!original_axis->IsSparse()) { @@ -823,19 +839,40 @@ class IterTransformer : public StmtExprMutator { /*! * \brief Perform binary search inside TIR. - * \return The buffer (size=1) containing the binary search result. + * \param buf The sparse buffer to be searched (must be sorted in ascending order on the last + * dimension). + * \param prefix_indices The prefix indices of the sparse buffer from the first dimension to the + * d-1 dimension (suppose `buf` is d-dimensional). + * \param lb The lower bound (close) of the search range [lb, ub) + * \param ub The upper bound (open) of the search range [lb, ub) + * \param val The value to be searched. + * \param search_type The binary search type: kLower, kEqual or kUpper, + * If kLower, returns the leftmost index of the locations whose value is greater or + * equal to `val`. + * If kEqual, returns the leftmost index of the locations whose value is greater or + * equal to `val`, and creates a `success` buffer stores success state of the search ( + * whether there exists elements equals the given value.) + * If kUpper, returns the leftmost index of the locations whose value is greater + * than `val`. + * If no such index exists, returns `ub`. + * \param minus_one Whether to minus one to the final result (when used together with + * `kUpper`, you will get the rightmost index of the suitable location to maintain + * ascending order). */ - PrimExpr BinarySearch(SparseBuffer buf, Array prefix_indices, PrimExpr lb, PrimExpr ub, - PrimExpr val, bool left, bool minus_one = false) { + std::pair> BinarySearch(SparseBuffer buf, + Array prefix_indices, PrimExpr lb, + PrimExpr ub, PrimExpr val, + BinarySearchType search_type, + bool minus_one = false) { /* Algorithm: - * - when left = true + * - lower_bound (search_type == kLower || search_type == kEqual) * - pre-condition * lb < ub, and the last dimension of buf is sorted. * - loop-invariant * low <= mid < high, buf[..., lb:low] < val, buf[..., high:ub] >= val * - post-condition * low = mid = high, buf[..., lb:low] < val, buf[..., high:ub] >= val - * - when left = false + * - upper_bound (search_type == kUpper) * - pre-condition * lb < ub, and the last dimension of buf is sorted. * - loop-invariant @@ -854,7 +891,7 @@ class IterTransformer : public StmtExprMutator { args.push_back(lb); args.push_back(ub); args.push_back(val); - args.push_back(Bool(left)); + args.push_back(Integer(static_cast(search_type))); args.push_back(Bool(minus_one)); if (bsearch_map_.count(args)) { return bsearch_map_[args]; @@ -908,6 +945,13 @@ class IterTransformer : public StmtExprMutator { String mid_buf_name = "mid_" + std::to_string(bsearch_blk_counter); SparseBuffer mid = SparseBuffer(Var(mid_buf_name, PointerType(PrimType(dtype), "global")), axes, dtype, mid_buf_name, Integer(0)); + String success_buf_name = "success_" + std::to_string(bsearch_blk_counter); + Optional success; + if (search_type == kEqual) { + success = SparseBuffer( + Var(success_buf_name, PointerType(PrimType(DataType{kDLInt, 8, 1}), "global")), axes, + DataType{kDLInt, 8, 1}, success_buf_name, Integer(0)); + } Stmt low_store = BufferStore(low, lb, {Integer(0)}); Stmt high_store = BufferStore(high, ub, {Integer(0)}); @@ -920,6 +964,7 @@ class IterTransformer : public StmtExprMutator { Array indices = prefix_indices; indices.push_back(mid_val); PrimExpr pivot = BufferLoad(buf, indices); + bool left = search_type == kLower || search_type == kEqual; PrimExpr pivot_cmp_cond = left ? (pivot < val) : (pivot > val); Stmt if_true = left ? BufferStore(low, mid_val + 1, {Integer(0)}) : BufferStore(high, mid_val, {Integer(0)}); @@ -933,6 +978,11 @@ class IterTransformer : public StmtExprMutator { body_stmts.push_back( BufferStore(mid, BufferLoad(mid, mid_indices) - Integer(1), mid_indices)); } + if (search_type == kEqual) { + body_stmts.push_back( + BufferStore(success.value(), + Select(mid_val != ub && pivot == val, Integer(1), Integer(0)), mid_indices)); + } SeqStmt body(body_stmts); String name = "binary_search_block_" + std::to_string(bsearch_blk_counter); @@ -952,8 +1002,11 @@ class IterTransformer : public StmtExprMutator { BufferRegion write = BufferRegion(mid, write_regions); bsearch_structures.push_back( BinarySearchStructure({name, body, var_map, inv_var_map, {low, high}, read, write})); - bsearch_map_[args] = mid_val; - return mid_val; + std::pair> ret_val = { + mid_val, search_type == kEqual ? BufferLoad(success.value(), mid_indices) + : Optional(NullOpt)}; + bsearch_map_[args] = ret_val; + return ret_val; } /*! \brief Return indices viewed in a given buffer. */ @@ -1031,7 +1084,8 @@ class IterTransformer : public StmtExprMutator { extent = buf_axis->nnz_cols.value(); } new_index = - BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, true); + BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, kLower) + .first; } else { // it's dense axis. new_index = coordinate; @@ -1086,7 +1140,8 @@ class IterTransformer : public StmtExprMutator { Map axis_indptr_map_; // axis to indptr buffer map. Map axis_indices_map_; // axis to indices buffer map. std::unordered_map base_dom_map_; // The base dom map. - std::unordered_map + std::unordered_map>, StructuralHash, + StructuralEqual> bsearch_map_; // The map storing existing binary search keys and values. int bsearch_blk_counter; // Counter for generated binary search blocks. };