From e2098bca9980312e01c59e890bb190afafe1b135 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Jul 2024 15:52:48 -0700 Subject: [PATCH] [Relax] Implement R.ensure_aligned and update memory planning for R.view --- python/tvm/relax/op/memory/__init__.py | 2 +- python/tvm/relax/op/memory/view.py | 17 +++ src/relax/backend/vm/vm_builtin_lower.cc | 20 ++++ src/relax/op/memory/view.cc | 34 +++++- src/relax/op/memory/view.h | 3 + .../transform/static_plan_block_memory.cc | 13 ++- src/runtime/relax_vm/builtin.cc | 13 +++ tests/python/relax/test_op_view.py | 105 +++++++----------- ...test_transform_static_plan_block_memory.py | 55 +++++++++ 9 files changed, 185 insertions(+), 77 deletions(-) diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py index 422c5d2e1f53a..2ae1b676e035c 100644 --- a/python/tvm/relax/op/memory/__init__.py +++ b/python/tvm/relax/op/memory/__init__.py @@ -17,4 +17,4 @@ """Relax memory primitives.""" from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor -from .view import view +from .view import view, ensure_aligned diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index 0c3d8a03b2dd7..233d07f6c9b71 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -92,3 +92,20 @@ def _normalize(expr, relax_cls): relative_byte_offset = _normalize(relative_byte_offset, PrimValue) return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore + + +def ensure_aligned(data: Expr) -> Expr: + """ + Ensure the tensor has elem_offset == 0. A copy will be made if necessary. + + Parameters + ---------- + data : relax.Expr + The input tensor + + Results + ------- + result : relax.Expr + The aligned tensor + """ + return _ffi_api.ensure_aligned(data) # type: ignore diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index 887998d004c72..961aa9b600f8b 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -47,6 +47,10 @@ class VMBuiltinLowerMutator : public ExprMutator { return Reshape(call); } else if (call->op == shape_of_op_) { return ShapeOf(call); + } else if (call->op == view_op_) { + return View(call); + } else if (call->op == ensure_aligned_op_) { + return EnsureAligned(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -124,6 +128,19 @@ class VMBuiltinLowerMutator : public ExprMutator { } } + Expr View(const Call& view_node) { + StructInfoDeriveFunc infer_sinfo_env_func; + infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo}); + } + + Expr EnsureAligned(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr ShapeOf(const Call& call_node) { ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); @@ -188,6 +205,8 @@ class VMBuiltinLowerMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& view_op_ = Op::Get("relax.memory.view"); + const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -208,6 +227,7 @@ class VMBuiltinLowerMutator : public ExprMutator { const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; + const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"}; }; Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index e7634c7edfceb..d43cc01838aea 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -334,13 +334,12 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { relative_byte_offset = relax::PrimValue::Int64(0); } - StructInfoDeriveFunc infer_sinfo_env_func; - infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); - auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); - - ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) && + relative_byte_offset.same_as(call->args[3])) { + return call; + } - return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); + return Call(call->op, {data, shape, dtype, relative_byte_offset}); } TVM_REGISTER_OP("relax.memory.view") @@ -355,5 +354,28 @@ TVM_REGISTER_OP("relax.memory.view") .set_attr("FLegalize", LegalizeView) .set_attr("FPurity", Bool(true)); +Expr ensure_aligned(const Expr& x) { + static const Op& op = Op::Get("relax.memory.ensure_aligned"); + return Call(op, {x}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned); + +StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Operator " << call->op << " should receive 1 argument, " + << "but received " << call->args); + } + return GetStructInfo(call->args[0]); +} + +TVM_REGISTER_OP("relax.memory.ensure_aligned") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FInferStructInfo", InferStructInfoEnsureAligned) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h index bc8002fa5b697..77ec7e9833cc9 100644 --- a/src/relax/op/memory/view.h +++ b/src/relax/op/memory/view.h @@ -32,6 +32,9 @@ namespace relax { /*! \brief View a tensor with different properties. */ Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset); +/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */ +Expr ensure_aligned(const Expr& x); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 2b16d8650906a..2922de6dcc7ee 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -286,8 +286,13 @@ class TokenAllocator1D { std::vector full_pool_; }; -/*! \brief Check if the input op is "relax.reshape". */ -bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); } +/*! \brief Check if the input op is a memory op that return the same buffer as the input buffer. */ +bool IsInplaceMemoryOp(const Expr& op) { + static const Op& reshape_op = Op::Get("relax.reshape"); + static const Op& view_op = Op::Get("relax.memory.view"); + static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned"); + return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op); +} /*! \brief The base class for the storage allocation visitor. */ class StorageAllocatorBaseVisitor : public ExprVisitor { @@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // Create a storage token for builtin alloc_tensor. this->CreateToken(call); return; - } else if (IsReshape(call->op)) { + } else if (IsInplaceMemoryOp(call->op)) { // Reuse the input's token for builtin reshape. SetTokens(call, GetTokens(call->args[0])); return; @@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { block_tokens.push_back(new_token.get()); } return; - } else if (IsReshape(call->op)) { + } else if (IsInplaceMemoryOp(call->op)) { Tokens tokens = GetTokens(call->args[0]); ICHECK(!tokens.IsNested()); if (tokens.IsLeaf()) { diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 2af31f1d4021f..83b016446548a 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -545,6 +545,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data return ShapeTuple(out_shape); }); +TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) { + if (data->byte_offset == 0) { + return data; + } + DLManagedTensor* dl_tensor = data.ToDLPack(); + dl_tensor->dl_tensor.data = + reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; + dl_tensor->dl_tensor.byte_offset = 0; + // For platforms that does not support pointer arithmetic, we need to copy the data to a new + // buffer. + return NDArray::FromDLPack(dl_tensor); +}); + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 2433821c2abd3..1e21612f9fff4 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([64, 64]), - R.dtype("float32"), - R.prim_value(0), - ) + B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) return B After = tvm.relax.transform.LegalizeOps()(Before) @@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")): class Expected: @R.function def main(A: R.Tensor(dtype="float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([64, 64]), - R.dtype("float32"), - R.prim_value(0), - ) + B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0) return B After = tvm.relax.transform.LegalizeOps()(Before) @@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([4096]), - R.dtype("int32"), - R.prim_value(0), + B = R.memory.view( + A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0 ) return B @@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")): class Expected: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( - A, - R.shape([4096]), - R.dtype("float32"), - R.prim_value(0), + B = R.memory.view( + A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32" ) return B @@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")): class Expected: @R.function def main(A: R.Tensor([4096], "uint8")): - B = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( + B = R.memory.view( A, - R.shape([512]), - R.dtype("int32"), - R.prim_value(0), + shape=R.shape([512]), + dtype=R.dtype("int32"), + relative_byte_offset=R.prim_value(0), ) - C = R.ExternFunc( - "runtime.TVMArrayCreateView", - R.Callable( - derive_func="tvm.relax.struct_info.infer_view_sinfo", - purity=True, - ), - )( + C = R.memory.view( A, - R.shape([16, 64]), - R.dtype("float16"), - R.prim_value(2048), + shape=R.shape([16, 64]), + dtype=R.dtype("float16"), + relative_byte_offset=R.prim_value(2048), ) return (B, C) @@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")): tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1]) +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.memory.view( + A, + shape=R.shape([16, 64]), + relative_byte_offset=32 * 64 * 4, + ) + C = R.memory.ensure_aligned(B) + return C + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input.reshape(64, 64)[32:48, :] + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 63f422d4cfbe7..61f8d6600568a 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1449,5 +1449,60 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_view(): + @I.ir_module + class Before: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main(): + cls = Before + x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0) + x1 = R.memory.view(x, [128], "float32", 0) + x2 = R.memory.ensure_aligned(x1) + y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(x2, y) + z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0) + cls.tir_exp(y, z) + return z + + @I.ir_module + class Expected: + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.evaluate(0) + + @R.function + def main() -> R.Tensor((128,), dtype="float32"): + cls = Module + storage: R.Object = R.memory.alloc_storage( + R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32") + ) + x1: R.Tensor((128,), dtype="float32") = R.memory.view( + x, R.shape([128]), R.dtype("float32"), R.prim_value(0) + ) + x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1) + storage1: R.Object = R.memory.alloc_storage( + R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([128]), R.dtype("float32") + ) + cls.tir_exp(x2, y) + z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global") + ) + cls.tir_exp(y, z) + return z + + after = relax.transform.StaticPlanBlockMemory()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main()