Skip to content

Commit

Permalink
Add an exact_match argument to BinarySearch function in Sparse It…
Browse files Browse the repository at this point in the history
…eration Lowering pass (#96)

* udp

* upd

* lint
  • Loading branch information
yzh119 authored Mar 24, 2023
1 parent ca59cbe commit 5b11e0f
Showing 1 changed file with 66 additions and 11 deletions.
77 changes: 66 additions & 11 deletions src/tir/transforms/lower_sparse_iter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,21 @@ class LowerSparseIterContext {
std::vector<Info> stack_;
};

/*!
* \brief The binary search type.
*/
typedef enum {
/*! \brief `lower_bound` search. */
kLower = 0U,
/*!
* \brief finds the element equals given value, and creates a `success` buffer to
* store successful state.
*/
kEqual = 1U,
/*! \brief `upper_bound` search. */
kUpper = 2U,
} BinarySearchType;

/*!
* \brief Lower sparse iterations by rewriting AST.
*/
Expand Down Expand Up @@ -761,7 +776,8 @@ class IterTransformer : public StmtExprMutator {
}
offset =
BinarySearch(indptr_buf, prefix_indices, Integer(0),
GetParentAxis(original_axis)->nnz + Integer(1), offset, false, true);
GetParentAxis(original_axis)->nnz + Integer(1), offset, kUpper, true)
.first;
}
Axis original_axis = GetAxisBeforeFuse(fused_axis->group[fused_axis->index]);
if (!original_axis->IsSparse()) {
Expand Down Expand Up @@ -823,19 +839,40 @@ class IterTransformer : public StmtExprMutator {

/*!
* \brief Perform binary search inside TIR.
* \return The buffer (size=1) containing the binary search result.
* \param buf The sparse buffer to be searched (must be sorted in ascending order on the last
* dimension).
* \param prefix_indices The prefix indices of the sparse buffer from the first dimension to the
* d-1 dimension (suppose `buf` is d-dimensional).
* \param lb The lower bound (close) of the search range [lb, ub)
* \param ub The upper bound (open) of the search range [lb, ub)
* \param val The value to be searched.
* \param search_type The binary search type: kLower, kEqual or kUpper,
* If kLower, returns the leftmost index of the locations whose value is greater or
* equal to `val`.
* If kEqual, returns the leftmost index of the locations whose value is greater or
* equal to `val`, and creates a `success` buffer stores success state of the search (
* whether there exists elements equals the given value.)
* If kUpper, returns the leftmost index of the locations whose value is greater
* than `val`.
* If no such index exists, returns `ub`.
* \param minus_one Whether to minus one to the final result (when used together with
* `kUpper`, you will get the rightmost index of the suitable location to maintain
* ascending order).
*/
PrimExpr BinarySearch(SparseBuffer buf, Array<PrimExpr> prefix_indices, PrimExpr lb, PrimExpr ub,
PrimExpr val, bool left, bool minus_one = false) {
std::pair<PrimExpr, Optional<PrimExpr>> BinarySearch(SparseBuffer buf,
Array<PrimExpr> prefix_indices, PrimExpr lb,
PrimExpr ub, PrimExpr val,
BinarySearchType search_type,
bool minus_one = false) {
/* Algorithm:
* - when left = true
* - lower_bound (search_type == kLower || search_type == kEqual)
* - pre-condition
* lb < ub, and the last dimension of buf is sorted.
* - loop-invariant
* low <= mid < high, buf[..., lb:low] < val, buf[..., high:ub] >= val
* - post-condition
* low = mid = high, buf[..., lb:low] < val, buf[..., high:ub] >= val
* - when left = false
* - upper_bound (search_type == kUpper)
* - pre-condition
* lb < ub, and the last dimension of buf is sorted.
* - loop-invariant
Expand All @@ -854,7 +891,7 @@ class IterTransformer : public StmtExprMutator {
args.push_back(lb);
args.push_back(ub);
args.push_back(val);
args.push_back(Bool(left));
args.push_back(Integer(static_cast<int>(search_type)));
args.push_back(Bool(minus_one));
if (bsearch_map_.count(args)) {
return bsearch_map_[args];
Expand Down Expand Up @@ -908,6 +945,13 @@ class IterTransformer : public StmtExprMutator {
String mid_buf_name = "mid_" + std::to_string(bsearch_blk_counter);
SparseBuffer mid = SparseBuffer(Var(mid_buf_name, PointerType(PrimType(dtype), "global")), axes,
dtype, mid_buf_name, Integer(0));
String success_buf_name = "success_" + std::to_string(bsearch_blk_counter);
Optional<SparseBuffer> success;
if (search_type == kEqual) {
success = SparseBuffer(
Var(success_buf_name, PointerType(PrimType(DataType{kDLInt, 8, 1}), "global")), axes,
DataType{kDLInt, 8, 1}, success_buf_name, Integer(0));
}

Stmt low_store = BufferStore(low, lb, {Integer(0)});
Stmt high_store = BufferStore(high, ub, {Integer(0)});
Expand All @@ -920,6 +964,7 @@ class IterTransformer : public StmtExprMutator {
Array<PrimExpr> indices = prefix_indices;
indices.push_back(mid_val);
PrimExpr pivot = BufferLoad(buf, indices);
bool left = search_type == kLower || search_type == kEqual;
PrimExpr pivot_cmp_cond = left ? (pivot < val) : (pivot > val);
Stmt if_true = left ? BufferStore(low, mid_val + 1, {Integer(0)})
: BufferStore(high, mid_val, {Integer(0)});
Expand All @@ -933,6 +978,11 @@ class IterTransformer : public StmtExprMutator {
body_stmts.push_back(
BufferStore(mid, BufferLoad(mid, mid_indices) - Integer(1), mid_indices));
}
if (search_type == kEqual) {
body_stmts.push_back(
BufferStore(success.value(),
Select(mid_val != ub && pivot == val, Integer(1), Integer(0)), mid_indices));
}
SeqStmt body(body_stmts);

String name = "binary_search_block_" + std::to_string(bsearch_blk_counter);
Expand All @@ -952,8 +1002,11 @@ class IterTransformer : public StmtExprMutator {
BufferRegion write = BufferRegion(mid, write_regions);
bsearch_structures.push_back(
BinarySearchStructure({name, body, var_map, inv_var_map, {low, high}, read, write}));
bsearch_map_[args] = mid_val;
return mid_val;
std::pair<PrimExpr, Optional<PrimExpr>> ret_val = {
mid_val, search_type == kEqual ? BufferLoad(success.value(), mid_indices)
: Optional<PrimExpr>(NullOpt)};
bsearch_map_[args] = ret_val;
return ret_val;
}

/*! \brief Return indices viewed in a given buffer. */
Expand Down Expand Up @@ -1031,7 +1084,8 @@ class IterTransformer : public StmtExprMutator {
extent = buf_axis->nnz_cols.value();
}
new_index =
BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, true);
BinarySearch(indices_buf, indices_path, Integer(0), extent, coordinate, kLower)
.first;
} else {
// it's dense axis.
new_index = coordinate;
Expand Down Expand Up @@ -1086,7 +1140,8 @@ class IterTransformer : public StmtExprMutator {
Map<Axis, SparseBuffer> axis_indptr_map_; // axis to indptr buffer map.
Map<Axis, SparseBuffer> axis_indices_map_; // axis to indices buffer map.
std::unordered_map<const VarNode*, arith::IntSet> base_dom_map_; // The base dom map.
std::unordered_map<ObjectRef, PrimExpr, StructuralHash, StructuralEqual>
std::unordered_map<ObjectRef, std::pair<PrimExpr, Optional<PrimExpr>>, StructuralHash,
StructuralEqual>
bsearch_map_; // The map storing existing binary search keys and values.
int bsearch_blk_counter; // Counter for generated binary search blocks.
};
Expand Down

0 comments on commit 5b11e0f

Please sign in to comment.