Skip to content

Commit

Permalink
apply eric's patch
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 2, 2024
1 parent 0abe71d commit a07aa4f
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 10 deletions.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
TopologicalSort,
UpdateParamStructInfo,
UpdateVDevice,
VMBuiltinLower,
VMShapeLower,
dataflowblock_pass,
function_pass,
Expand Down
6 changes: 0 additions & 6 deletions src/relax/op/memory/view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,6 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {

TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);

Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
// No-op. View is lowered during the LowerBuiltinView pass.
return call;
}

Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
Expr data = call->args[0];
Expr shape = call->args[1];
Expand Down Expand Up @@ -357,7 +352,6 @@ TVM_REGISTER_OP("relax.memory.view")
"The view's byte offset, relative to the input tensor's byte offset.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
.set_attr<FLegalize>("FLegalize", LegalizeView)
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);

Expand Down
4 changes: 3 additions & 1 deletion tests/python/relax/test_op_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")):


def test_legalize_is_no_op():
"""R.memory.view is not legalized until LowerRuntimeBuiltin"""

@I.ir_module
class Before:
@R.function
Expand All @@ -462,7 +464,7 @@ def main(A: R.Tensor([4096], "float32")):

Expected = Before

After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
After = tvm.relax.transform.LegalizeOps()(Before)
tvm.ir.assert_structural_equal(Expected, After)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
tvm.ir.assert_structural_equal(mod, Expected)
mod = relax.transform.LowerAllocTensor()(mod)
mod = relax.transform.KillAfterLastUse()(mod)
mod = relax.transform.VMBuiltinLower()(mod)
mod = relax.transform.LowerRuntimeBuiltin()(mod)
tvm.ir.assert_structural_equal(mod, ExpectedLowered)


Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_vm_builtin_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
gv0 = alloc
return gv0

After = relax.transform.VMBuiltinLower()(Before)
After = relax.transform.LowerRuntimeBuiltin()(Before)
tvm.ir.assert_structural_equal(Expected, After)


Expand All @@ -79,7 +79,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
return gv0

with pytest.raises(tvm.TVMError):
relax.transform.VMBuiltinLower()(Before)
relax.transform.LowerRuntimeBuiltin()(Before)


if __name__ == "__main__":
Expand Down

0 comments on commit a07aa4f

Please sign in to comment.