Skip to content

Commit

Permalink
fix when binary search finds an invalid value (#95)
Browse files Browse the repository at this point in the history
* reopen pr

* move invalid indices postprocess pass to lower_sparse_iter

* refine

* upd

* lint

* add comments and fix bug
  • Loading branch information
qelk123 authored Mar 25, 2023
1 parent 5b11e0f commit 9c26e84
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 79 deletions.
13 changes: 10 additions & 3 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,24 +271,30 @@ class SparseBufferNode : public BufferNode {
PrimExpr GetNNZ() const;

Buffer flattened;

/*!
* \brief The default value in the sparse buffer.
*/
Optional<PrimExpr> default_value;
/*! \brief Register all fields with the attribute visitor (reflection/serialization). */
void VisitAttrs(AttrVisitor* v) {
  // Visit inherited Buffer fields first, then the sparse-specific ones.
  BufferNode::VisitAttrs(v);
  v->Visit("axes", &axes);
  v->Visit("extra_storage", &extra_storage);
  v->Visit("flattened", &flattened);
  v->Visit("default_value", &default_value);
}

/*!
 * \brief Structural equality: base Buffer fields plus axes, extra_storage,
 *        flattened buffer, and default_value must all match.
 * \param other The node to compare against.
 * \param equal The structural-equality reducer.
 * \return Whether the two nodes are structurally equal.
 */
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
  return BufferNode::SEqualReduce(other, equal) && equal(axes, other->axes) &&
         equal(extra_storage, other->extra_storage) && equal(flattened, other->flattened) &&
         equal(default_value, other->default_value);
}

/*! \brief Fold all fields into the structural hash; field set mirrors SEqualReduce. */
void SHashReduce(SHashReducer hash_reduce) const {
  BufferNode::SHashReduce(hash_reduce);
  hash_reduce(axes);
  hash_reduce(extra_storage);
  hash_reduce(flattened);
  hash_reduce(default_value);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -304,7 +310,8 @@ class SparseBufferNode : public BufferNode {
/*!
 * \brief Managed reference to SparseBufferNode.
 * \sa SparseBufferNode
 */
class SparseBuffer : public Buffer {
 public:
  /*!
   * \brief Construct a sparse buffer.
   * \param data The backing data variable.
   * \param axes The axes describing the sparse layout.
   * \param dtype The element data type.
   * \param name The buffer name.
   * \param extra_storage Required extra storage (e.g. for indptr).
   * \param default_value The value assumed for entries absent from the sparse
   *        format; NullOpt keeps the previous behavior (zero of dtype).
   * \param span The source span.
   */
  TVM_DLL explicit SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String name,
                                Optional<PrimExpr> extra_storage,
                                Optional<PrimExpr> default_value = NullOpt, Span span = Span());
  TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, Buffer, SparseBufferNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ TVM_DLL Pass RenormalizeSplitPattern();
* \brief Lower sparse iterations in Sparse TIR.
* \return The pass.
*/
TVM_DLL Pass LowerSparseIter();
TVM_DLL Pass LowerSparseIter(bool check_invalid_binary_search = false);

/*!
* \brief Lower sparse buffers in Sparse TIR.
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,6 @@ def preflattened_buffer(
buffer_type="default",
span=None,
):

param = None
for key, value in self.context.func_buffer_map.items():
if value.same_as(postflattened):
Expand Down Expand Up @@ -1194,6 +1193,7 @@ def match_sparse_buffer(
axes: List[Axis],
dtype: str = "float32",
extra_storage: Optional[PrimExpr] = None,
default_value: Optional[PrimExpr] = None,
span: Optional[Span] = None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
Expand All @@ -1214,7 +1214,7 @@ def match_sparse_buffer(
storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
data = Var(buffer_name, PointerType(storage_type, "global"), span)
buffer = tvm.tir.sparse.SparseBuffer(
data, axes, dtype, buffer_name, extra_storage, span
data, axes, dtype, buffer_name, extra_storage, default_value, span
)
self.context.func_buffer_map[param] = buffer
self.context.update_symbol(buffer_name, buffer, self.node)
Expand All @@ -1235,6 +1235,7 @@ def alloc_sparse_buffer(
axes: List[Axis],
dtype: str = "float32",
scope: str = "global",
default_value: Optional[PrimExpr] = None,
span: Optional[Span] = None,
):
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
Expand All @@ -1245,7 +1246,9 @@ def alloc_sparse_buffer(
buffer_name: str = self.node.lhs[0].id.name

data = Var(buffer_name, PointerType(PrimType(dtype), scope), span)
buffer = tvm.tir.sparse.SparseBuffer(data, axes, dtype, buffer_name, 0, span)
buffer = tvm.tir.sparse.SparseBuffer(
data, axes, dtype, buffer_name, 0, default_value, span
)
if self.context.current_block_scope():
self.context.current_block_scope().alloc_buffers.append(buffer)
else:
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/sparse/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@
from tvm.tir.transform import LowerSparseBuffer, LowerSparseIter


def lower_sparse_iter(mod: IRModule, check_invalid_binary_search: bool = False):
    """Lower sparse iterators in Sparse TIR.

    Parameters
    ----------
    mod : IRModule
        The IRModule to lower.
    check_invalid_binary_search : bool
        Whether to check for invalid indices produced by binary search.

    Returns
    -------
    IRModule
        The lowered IRModule.

    Raises
    ------
    TypeError
        If ``mod`` is not an :class:`IRModule`.
    """
    if not isinstance(mod, IRModule):
        raise TypeError("Expected IRModule, but got {}".format(type(mod)))
    return LowerSparseIter(check_invalid_binary_search)(mod)


def lower_sparse_buffer(mod: IRModule):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ class SparseBuffer(Buffer):
The name of the sparse buffer
extra_storage : Optional[PrimExpr]
Required extra storage (e.g. for indptr)
default_value : Optional[PrimExpr]
The default value used for entries missing from the sparse buffer
span : Span
"""

Expand All @@ -306,10 +308,11 @@ class SparseBuffer(Buffer):
dtype: str
name: str
extra_storage: Optional[PrimExpr]
default_value: Optional[PrimExpr]
span: Span

def __init__(self, data, axes, dtype, name, extra_storage, default_value, span):
    # Construct the underlying SparseBuffer node through the FFI.
    # default_value is the value assumed for entries absent from the
    # sparse format (may be None, in which case the C++ side picks a default).
    self.__init_handle_by_constructor__(
        _ffi_api.SparseBuffer, data, axes, dtype, name, extra_storage, default_value, span
    )  # type: ignore


@tvm._ffi.register_object("tir.sparse.SpIterVar")
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,15 +806,20 @@ def RenormalizeSplitPattern():
return _ffi_api.RenormalizeSplitPattern() # type: ignore


def LowerSparseIter(check_invalid_binary_search: bool = False):
    """Lower iterations in Sparse TIR.

    Parameters
    ----------
    check_invalid_binary_search : bool
        Whether to check for invalid indices produced by binary search.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LowerSparseIter(check_invalid_binary_search)  # type: ignore


def LowerSparseBuffer():
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmt_(const SparseIterationNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

Doc VisitType_(const PrimTypeNode* node) override;
Expand Down
21 changes: 21 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,27 @@ Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
return doc;
}

/*!
 * \brief Print a SparseIteration statement as
 *        `sparse_iteration name(v0,v1,...) { [with init() {...}] body }`.
 * \note The original indexed `sp_iter_vars[0]` unconditionally, which is
 *       out-of-bounds for an empty iter-var list; the joined loop below
 *       produces identical output for non-empty lists and is safe for empty.
 */
Doc TIRTextPrinter::VisitStmt_(const SparseIterationNode* op) {
  Doc doc;
  doc << "sparse_iteration " << op->name << "(";
  // Comma-separated sparse iteration variables.
  for (size_t i = 0; i < op->sp_iter_vars.size(); ++i) {
    if (i != 0) {
      doc << ",";
    }
    doc << Print(op->sp_iter_vars[i]->var);
  }
  doc << ")";
  Doc body;
  // The optional init block is printed before the main body.
  if (op->init.defined()) {
    Doc init_block;
    init_block << "with init()";
    init_block << PrintBody(op->init.value());
    body << init_block << Doc::NewLine();
  }
  // Print body
  body << Print(op->body);
  doc << " {" << Doc::Indent(2, Doc::NewLine() << body) << Doc::NewLine() << "}";
  return doc;
}

Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
const auto* block_op = op->block.as<BlockNode>();
// print block name and block vars
Expand Down
5 changes: 5 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (sp_buf->extra_storage.defined()) {
doc << ", extra_storage=" << Print(sp_buf->extra_storage.value());
}

// default value
if (sp_buf->default_value.defined()) {
doc << ", default_value=" << Print(sp_buf->default_value.value());
}
// scope
const auto* ptr_type = sp_buf->data->type_annotation.as<PointerTypeNode>();
ICHECK(ptr_type) << "Buffer variable is not of pointer type";
Expand Down
18 changes: 15 additions & 3 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ PrimExpr SparseBufferNode::GetNNZ() const { return flattened->shape[0]; }

/*! \brief Default constructor of SparseBuffer */
SparseBuffer::SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String name,
Optional<PrimExpr> extra_storage, Span span) {
Optional<PrimExpr> extra_storage, Optional<PrimExpr> default_value,
Span span) {
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
CHECK_GT(static_cast<int>(axes.size()), 0)
<< "ValueError: A SparseBuffer should have at least one dimension";
Expand Down Expand Up @@ -275,6 +276,13 @@ SparseBuffer::SparseBuffer(Var data, Array<Axis> axes, DataType dtype, String na
node->extra_storage = extra_storage;
node->name = name;
node->dtype = dtype;
if (!default_value) {
node->default_value = Cast(dtype, Integer(0));
} else {
ICHECK(default_value.value()->dtype == dtype)
<< "sparse buffer default value should match buffer data type";
node->default_value = default_value;
}
// collect shape
Array<PrimExpr> shape;
for (const Axis& axis : axes) {
Expand Down Expand Up @@ -307,9 +315,10 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](Var data, Array<Axis> axes, DataType dtype, String name,
Optional<PrimExpr> extra_storage, Span span) {
Optional<PrimExpr> extra_storage, Optional<PrimExpr> default_value,
Span span) {
return SparseBuffer(std::move(data), std::move(axes), std::move(dtype), std::move(name),
std::move(extra_storage), std::move(span));
std::move(extra_storage), std::move(default_value), std::move(span));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -327,6 +336,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
if (op->extra_storage.defined()) {
p->stream << ", " << op->extra_storage.value();
}
if (op->default_value.defined()) {
p->stream << ", " << op->default_value.value();
}
p->stream << ")";
});

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
return buffer;
} else {
return SparseBuffer(sp_buf->data, std::move(axes), sp_buf->dtype, sp_buf->name,
sp_buf->extra_storage, sp_buf->span);
sp_buf->extra_storage, sp_buf->default_value, sp_buf->span);
}
} else {
Array<PrimExpr> shape =
Expand Down
Loading

0 comments on commit 9c26e84

Please sign in to comment.