Skip to content

Commit

Permalink
rename to sp_iter (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored Mar 6, 2023
1 parent f0dac6d commit ca59cbe
Show file tree
Hide file tree
Showing 17 changed files with 188 additions and 78 deletions.
4 changes: 2 additions & 2 deletions examples/blocksparse/bsr_spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def func(
B = T.match_sparse_buffer(b, (J_detach, BJ, F), "float16")
C = T.match_sparse_buffer(c, (I, BI, F), "float16")

with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [
with T.sp_iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [
i,
bi,
bj,
Expand Down Expand Up @@ -87,7 +87,7 @@ def func(
B = T.match_sparse_buffer(b, (J_detach, BJ, F), "float16")
C = T.match_sparse_buffer(c, (I_detach, BI, F), "float16")

with T.iter([O, I, BI, BJ, F, J], "SSSRSR", "bsrmm") as [
with T.sp_iter([O, I, BI, BJ, F, J], "SSSRSR", "bsrmm") as [
o,
i,
bi,
Expand Down
12 changes: 6 additions & 6 deletions examples/flash_graph_attention/fusedmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,26 @@ def fusedmm(
)
softmax = T.alloc_sparse_buffer([I, J], "float32")
# Q^T * K
with T.iter([I, J, F], "SSR", "sddmm") as [i, j, f]:
with T.sp_iter([I, J, F], "SSR", "sddmm") as [i, j, f]:
with T.init():
score[i, j] = T.float32(0)
score[i, j] += Q[i, f] * K[j, f]

# softmax
with T.iter([I], "S", "softmax") as [i]:
with T.iter([J], "R", "computer_max") as [j]:
with T.sp_iter([I], "S", "softmax") as [i]:
with T.sp_iter([J], "R", "computer_max") as [j]:
with T.init():
temp[i] = T.min_value("float32")
temp[i] = T.max(temp[i], score[i, j])
with T.iter([J], "R", "sum_of_exp") as [j]:
with T.sp_iter([J], "R", "sum_of_exp") as [j]:
with T.init():
temp1[i] = T.float32(0)
temp1[i] += T.exp(score[i, j] - temp[i], dtype="float32")
with T.iter([J], "S", "normalize") as [j]:
with T.sp_iter([J], "S", "normalize") as [j]:
softmax[i, j] = T.exp(score[i, j], dtype="float32") / temp1[i]

# softmax * V
with T.iter([I, J, F], "SRS", "spmm") as [i, j, f]:
with T.sp_iter([I, J, F], "SRS", "spmm") as [i, j, f]:
with T.init():
O[i, f] = T.float32(0)
O[i, f] = O[i, f] + softmax[i, j] * V[j, f]
Expand Down
4 changes: 2 additions & 2 deletions examples/rgms/rgcn/bench_rgcn_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def func(
Y = T.match_sparse_buffer(y, (I_detach, F_out), dtype)
WX = T.match_sparse_buffer(wx, (R, I, J, F_out), dtype)

with T.iter([R, I, J, F_out, F_in], "SSSSR", "rgcn-hetero-forward_wx") as [
with T.sp_iter([R, I, J, F_out, F_in], "SSSSR", "rgcn-hetero-forward_wx") as [
r,
i,
j,
Expand All @@ -601,7 +601,7 @@ def func(
WX[r, i, j, fo] = T.cast(0, dtype)
WX[r, i, j, fo] += T.cast(X[j, fi] * W[r, fi, fo], dtype)

with T.iter([R, I, J, F_out], "SSRS", "rgcn-hetero-forward") as [r, i, j, fo]:
with T.sp_iter([R, I, J, F_out], "SSRS", "rgcn-hetero-forward") as [r, i, j, fo]:
with T.init():
Y[i, fo] = T.cast(0, dtype)
Y[i, fo] = Y[i, fo] + T.cast(A[r, i, j], dtype) * WX[r, i, j, fo]
Expand Down
2 changes: 1 addition & 1 deletion examples/rgms/rgcn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def rgcn_hetero_forward(
W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
X = T.match_sparse_buffer(x, (J_detach, F_in), "float32")
Y = T.match_sparse_buffer(y, (I_detach, F_out), "float32")
with T.iter([F_out, R, I, J, F_in], "SSSRR", "rgcn-hetero-forward") as [fo, r, i, j, fi]:
with T.sp_iter([F_out, R, I, J, F_in], "SSSRR", "rgcn-hetero-forward") as [fo, r, i, j, fi]:
with T.init():
Y[i, fo] = 0.0
Y[i, fo] = Y[i, fo] + A[r, i, j] * W[r, fo, fi] * X[j, fi]
Expand Down
4 changes: 2 additions & 2 deletions examples/rgms/sparse_conv/rgms.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def func(
Y = T.match_sparse_buffer(y, (I_detach, F_out), "float16")
WX = T.match_sparse_buffer(wx, (R, I, J, F_out), "float16")

with T.iter([R, I, J, F_out, F_in], "SSSSR", "rgcn-hetero-forward_wx") as [
with T.sp_iter([R, I, J, F_out, F_in], "SSSSR", "rgcn-hetero-forward_wx") as [
r,
i,
j,
Expand All @@ -592,7 +592,7 @@ def func(
WX[r, i, j, fo] = T.float16(0)
WX[r, i, j, fo] += X[j, fi] * W[r, fi, fo]

with T.iter([R, I, J, F_out], "SSRS", "rgcn-hetero-forward") as [r, i, j, fo]:
with T.sp_iter([R, I, J, F_out], "SSRS", "rgcn-hetero-forward") as [r, i, j, fo]:
with T.init():
Y[i, fo] = T.float16(0)
Y[i, fo] = Y[i, fo] + A[r, i, j] * WX[r, i, j, fo]
Expand Down
2 changes: 1 addition & 1 deletion examples/sddmm/bench_sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def func(
B = T.match_sparse_buffer(b, (J_detach, K), "float32")
C = T.match_sparse_buffer(c, (I, J), "float32")

with T.iter([I, J, K], "SSR", "sddmm") as [i, j, k]:
with T.sp_iter([I, J, K], "SSR", "sddmm") as [i, j, k]:
with T.init():
C[i, j] = 0.0
C[i, j] = C[i, j] + A[i, k] * B[j, k]
Expand Down
2 changes: 1 addition & 1 deletion examples/spmm/bench_spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def csrmm(
A = T.match_sparse_buffer(a, (I, J), "float32")
B = T.match_sparse_buffer(b, (J_detach, K1, K2, K3), "float32")
C = T.match_sparse_buffer(c, (I, K1, K2, K3), "float32")
with T.iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
with T.sp_iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
with T.init():
C[i, k1, k2, k3] = 0.0
C[i, k1, k2, k3] = C[i, k1, k2, k3] + A[i, j] * B[j, k1, k2, k3]
Expand Down
2 changes: 1 addition & 1 deletion examples/spmm/bench_spmm_naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def csrmm(
A = T.match_sparse_buffer(a, (I, J), "float32")
B = T.match_sparse_buffer(b, (J_detach, K1, K2, K3), "float32")
C = T.match_sparse_buffer(c, (I, K1, K2, K3), "float32")
with T.iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
with T.sp_iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
with T.init():
C[i, k1, k2, k3] = T.float32(0)
C[i, k1, k2, k3] = C[i, k1, k2, k3] + A[i, j] * B[j, k1, k2, k3]
Expand Down
2 changes: 1 addition & 1 deletion examples/spmm/bench_tc_spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def tcspmm(
A = T.match_sparse_buffer(a, [IO, JO, II, JI], "float16")
B = T.match_sparse_buffer(b, [J, F], "float16")
C = T.match_sparse_buffer(c, [IO, II, F], "float16")
with T.iter([IO, JO, II, JI, F], "SRSRS", "tcspmm") as [io, jo, ii, ji, f]:
with T.sp_iter([IO, JO, II, JI, F], "SRSRS", "tcspmm") as [io, jo, ii, ji, f]:
with T.init():
C[io, ii, f] = T.float16(0)
C[io, ii, f] = C[io, ii, f] + A[io, jo, ii, ji] * B[ji, f]
Expand Down
74 changes: 74 additions & 0 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,80 @@ def enter_scope(
class SparseIteration(WithScopeHandler):
    """With scope handler of SparseIteration.

    Implements the TVMScript ``with T.sp_iter([...], "...", "name") as [...]``
    construct: it binds one loop variable per (flattened) sparse axis and
    builds a ``tvm.tir.SparseIteration`` node from the ``with`` body.
    """

    def __init__(self):
        def sp_iter(axes: List, iter_types: str, name_hint: str = "", span: Optional[Span] = None):
            # Called on scope exit; by then enter_scope must have populated
            # self.node / self.context / self.body.
            assert (
                self.node and self.context and self.body
            ), "call 'exit_scope' before 'enter_scope'"
            block_info = self.context.block_info_stack[-1]
            # Fused axes may expand into several atomic axes, so flatten first.
            axes = flatten_axes(axes)

            # One iteration variable name is required per flattened axis.
            if len(axes) != len(self.sp_iters):
                self.context.report_error(
                    "Inconsistent number of sparse iteration variable names, "
                    + f"there are {len(axes)} iterators but {len(self.sp_iters)} names. "
                    + "The number of sparse iteration variable names should match the number of iterators.",
                    self.node.span,
                )
            # One iterator-type character ("S" spatial / "R" reduction) per axis.
            if len(axes) != len(iter_types):
                self.context.report_error(
                    "Inconsistent number of sparse iteration variable types, "
                    + f"there are {len(axes)} iterators but {len(iter_types)} types. "
                    + "The number of sparse iteration variable types should match the number of iterators.",
                    self.node.span,
                )

            sp_iters: List[SpIterVar] = []
            for i, axis in enumerate(axes):
                # "R" marks a reduction iterator; every other character is spatial.
                is_reduction = True if iter_types[i] == "R" else False
                if isinstance(axis, FusedAxis):
                    # NOTE(review): `length` is computed but not used below in the
                    # visible diff — possibly dead code or truncated context; confirm.
                    length = axis.group[axis.index].length
                sp_iters.append(
                    SpIterVar(
                        self.sp_iters[i],
                        is_reduction,
                        axis,
                    )
                )

            # Assemble the SparseIteration IR node from the collected iterators
            # and the block info (init block / annotations) gathered while parsing.
            block = tvm.tir.SparseIteration(
                sp_iters,
                name_hint,
                self.body,
                block_info.init,
                block_info.annotations,
                span,
            )
            return block

        super().__init__(func=sp_iter, concise_scope=False, def_symbol=True)
        # Filled in by enter_scope with the tir.Var objects for the `as [...]` names.
        self.sp_iters = None

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        # define sparse iteration variables
        assert isinstance(
            node, synr.ast.With
        ), f"SparseIteration ScopeHandler expected to work on synr.ast.With but got {type(node)}"

        # Names come from the optional `as [i, j, ...]` target of the with-statement.
        vars = WithScopeHandler.get_optional_vars(node, context)
        axes = flatten_axes(arg_list[0])
        # Each loop variable takes the index dtype of its corresponding axis.
        # NOTE(review): zip silently truncates on a name/axis count mismatch here;
        # the mismatch is only reported later in sp_iter().
        self.sp_iters = [
            tvm.tir.Var(var.id.name, dtype=axis.idtype) for var, axis in zip(vars, axes)
        ]
        # Register each variable so the with-body can reference it by name.
        for sp_iter in self.sp_iters:
            context.update_symbol(sp_iter.name, sp_iter, node)


@register
class LegacySparseIteration(WithScopeHandler):
"""Legacy scope handler of SparseIteration to ensure the compatibility with previous T.iter interface."""

def __init__(self):
def iter(axes: List, iter_types: str, name_hint: str = "", span: Optional[Span] = None):
assert (
Expand Down
40 changes: 38 additions & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,44 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None):


@register
class SparseIterationAttr(SpecialStmt):
    """Special function sp_iter_attr({attr_key: attr_value})

    Attaches an annotation dictionary to the enclosing block scope; the
    annotations are later consumed when the SparseIteration node is built.

    Example
    -------
    .. code-block:: python
        T.sp_iter_attr({"preprocess": True})
    """

    def __init__(self):
        def sp_iter_attr(attrs: Mapping[str, Object], span: Span = None):
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            block_scope = self.context.current_block_scope()
            # Annotations are only meaningful inside a block scope.
            if block_scope is None:
                self.context.report_error(
                    "Expected to declare block annotations inside a block.",
                    span,
                )
            # Only one sp_iter_attr call is allowed per block.
            if block_scope.annotations is not None:
                self.context.report_error(
                    "Duplicate block annotations declaration, "
                    + "previous one is "
                    + str(block_scope.annotations),
                    span,
                )
            # Plain Python strings must be wrapped as TVM String objects so they
            # can be stored in the IR; other values are stored as-is.
            attrs = {
                key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
            }
            block_scope.annotations = attrs

        # def_symbol=False: this statement defines no new symbols in scope.
        super().__init__(sp_iter_attr, def_symbol=False)


@register
class LegacySparseIterationAttr(SpecialStmt):
"""Legacy Sparse Iteration Attribution API to ensure compatibility.
iter_attr({attr_key: attr_value})
Example
-------
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {

Doc TVMScriptPrinter::PrintSparseIterationName(const SparseIterationNode* op) {
Doc doc;
doc << "with " << tir_prefix_ << ".iter([";
doc << "with " << tir_prefix_ << ".sp_iter([";

int n_iter = static_cast<int>(op->sp_iter_vars.size());

Expand Down
Loading

0 comments on commit ca59cbe

Please sign in to comment.