[Relax] Implement R.ensure_zero_offset and update memory planning for R.view #17145

Merged: 12 commits, Aug 6, 2024
2 changes: 1 addition & 1 deletion include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
*
* \return The Pass.
*/
TVM_DLL Pass VMBuiltinLower();
TVM_DLL Pass LowerRuntimeBuiltin();

/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR functions.
9 changes: 9 additions & 0 deletions include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*! \brief The function type for lowering a builtin operator.
*
* An operator carrying this attribute is rewritten to its runtime builtin
* form by the `LowerRuntimeBuiltin` pass.
*
* \param bb The BlockBuilder context.
* \param call The call to be lowered.
* \return The lowered expression.
*/
using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*!
* \brief Gradient for a specific op.
*
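The consumer of this attribute is the mutator in lower_runtime_builtin.cc below, and the C++ registration pattern appears in view.cc. For illustration only (not part of this PR), a Python-side registration might look like the following sketch, assuming the generic tvm.ir.register_op_attr helper accepts packed-func attributes such as "FLowerBuiltin"; the op name "relax.my_op" and extern "vm.builtin.my_builtin" are hypothetical:

import tvm
from tvm import relax

def lower_my_builtin(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    # Rewrite the op call into a call of a runtime packed function,
    # preserving the original struct info. ("vm.builtin.my_builtin" is
    # a hypothetical extern name used only for illustration.)
    return relax.Call(
        relax.ExternFunc("vm.builtin.my_builtin"),
        call.args,
        sinfo_args=[call.struct_info],
    )

tvm.ir.register_op_attr("relax.my_op", "FLowerBuiltin", lower_my_builtin)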
5 changes: 5 additions & 0 deletions include/tvm/runtime/device_api.h
@@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
return device_type != kDLCPU && device_type != kDLMicroDev;
}

/*!
* \brief Whether pointer arithmetic on a device-owned pointer may be performed on the host.
*/
virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; }

protected:
/*!
* \brief copy data from one place to another
2 changes: 1 addition & 1 deletion python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view
from .view import view, ensure_zero_offset
17 changes: 17 additions & 0 deletions python/tvm/relax/op/memory/view.py
@@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore


def ensure_zero_offset(data: Expr) -> Expr:
"""
Ensure the tensor has elem_offset == 0. A copy will be made if necessary.

Parameters
----------
data : relax.Expr
The input tensor

Returns
-------
result : relax.Expr
The tensor with elem_offset == 0
"""
return _ffi_api.ensure_zero_offset(data) # type: ignore
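As a usage sketch (constructed for illustration; not taken from this PR's tests), pairing view with ensure_zero_offset guarantees that downstream consumers observe a zero-offset tensor, copying only when the view starts mid-buffer:

from tvm import relax
from tvm.relax.op.memory import ensure_zero_offset, view

# View the second half of a (16, 16) float32 tensor; 512 bytes is
# 8 rows * 16 cols * 4 bytes, an example offset chosen for illustration.
x = relax.Var("x", relax.TensorStructInfo([16, 16], "float32"))
half = view(x, shape=relax.ShapeExpr([8, 16]), relative_byte_offset=relax.PrimValue(512))
# Copies only if the nonzero offset cannot simply be folded into the pointer.
out = ensure_zero_offset(half)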
2 changes: 1 addition & 1 deletion python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
transform.RewriteCUDAGraph(),
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
transform.VMBuiltinLower(),
transform.LowerRuntimeBuiltin(),
transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
8 changes: 4 additions & 4 deletions python/tvm/relax/transform/__init__.py
@@ -55,6 +55,7 @@
LegalizeOps,
LiftTransformParams,
LowerAllocTensor,
LowerRuntimeBuiltin,
MergeCompositeFunctions,
MetaScheduleApplyDatabase,
MetaScheduleTuneIRMod,
@@ -64,8 +65,8 @@
PatternCheckContext,
RealizeVDevice,
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
RemoveUnusedParameters,
ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
@@ -84,14 +85,13 @@
function_pass,
)

from .attach_external_modules import AttachExternModules
from .fast_math import FastMathTransform
from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .fast_math import FastMathTransform
from .fuse_transpose_matmul import FuseTransposeMatmul
from .attach_external_modules import AttachExternModules

# Import to register the legalization functions.
from . import legalize_ops, tuning_api
17 changes: 16 additions & 1 deletion python/tvm/relax/transform/transform.py
@@ -19,6 +19,7 @@
import functools
import inspect
import types
import warnings
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np # type: ignore
@@ -586,14 +587,28 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
return _ffi_api.ComputePrimValue() # type: ignore


def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
"""Lower generic intrinsics to VM intrinsics.

Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.LowerRuntimeBuiltin() # type: ignore


def VMBuiltinLower() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.

Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMBuiltinLower() # type: ignore
warnings.warn(
"tvm.relax.transform.VMBuiltinLower has been renamed to 'LowerRuntimeBuiltin'. "
"This wrapper is for backwards compatibility, and will be removed in a later update."
)
return _ffi_api.LowerRuntimeBuiltin() # type: ignore


def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
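A quick sketch of the backwards-compatibility behavior (illustrative; not from this PR's test suite): the old name still returns a working pass, but emits a warning pointing at the new one.

import warnings

from tvm import relax

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecated_pass = relax.transform.VMBuiltinLower()  # deprecated spelling

assert any("LowerRuntimeBuiltin" in str(w.message) for w in caught)
# Both names construct the same underlying lowering pass.
current_pass = relax.transform.LowerRuntimeBuiltin()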
src/relax/backend/vm/{vm_builtin_lower.cc → lower_runtime_builtin.cc}
@@ -17,13 +17,14 @@
* under the License.
*/
/*!
* \file src/relax/backend/vm/vm_builtin_lower.cc
* \file src/relax/backend/vm/lower_runtime_builtin.cc
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/op.h>
@@ -33,11 +34,12 @@ namespace relax {

// This pass lowers most ops to VM-specific builtins.
// TODO(relax-team): revisit after PrimValue.
class VMBuiltinLowerMutator : public ExprMutator {
class LowerRuntimeBuiltinMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call_node) final {
static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));

@@ -64,9 +66,13 @@
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
return MakeMemKillObject(call);
} else {
return call;
} else if (const auto* op_node = call->op.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
if (lower_builtin_fmap.count(op)) {
return lower_builtin_fmap[op](builder_, call);
}
}
return call;
}

Expr MakeMemAllocStorage(const Call& call) {
@@ -210,17 +216,19 @@
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};

Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }

namespace transform {

Pass VMBuiltinLower() {
Pass LowerRuntimeBuiltin() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(LowerRuntimeBuiltin(f));
};
return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
}

TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);

} // namespace transform
} // namespace relax
35 changes: 32 additions & 3 deletions src/relax/op/memory/view.cc
@@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {

TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);

Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
Expr data = call->args[0];
Expr shape = call->args[1];
Expr dtype = call->args[2];
@@ -352,8 +352,37 @@ TVM_REGISTER_OP("relax.memory.view")
"The view's byte offset, relative to the input tensor's byte offset.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
.set_attr<FLegalize>("FLegalize", LegalizeView)
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);

Expr ensure_zero_offset(const Expr& x) {
static const Op& op = Op::Get("relax.memory.ensure_zero_offset");
return Call(op, {x});
}

TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset);

StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << call->op << " should receive 1 argument, "
<< "but received " << call->args);
}
return GetStructInfo(call->args[0]);
}

Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) {
const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)});
}

TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureZeroOffset)
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);

} // namespace relax
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relax/op/memory/view.h
@@ -32,6 +32,9 @@ namespace relax {
/*! \brief View a tensor with different properties. */
Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);

/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
Expr ensure_zero_offset(const Expr& x);

} // namespace relax
} // namespace tvm

13 changes: 9 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc
@@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};

/*! \brief Check if the input op is "relax.reshape". */
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
/*! \brief Check if the input op is a memory op that may return the same buffer. */
Reviewer comment (Contributor): Thank you for the updated docstring. As I'm looking at it, we may want to add this as another operator attribute (e.g. .set_attr<Bool>("ReturnMayAliasArgument", Bool(true))), but that could be a follow-up PR instead.
bool IsInplaceMemoryOp(const Expr& op) {
static const Op& reshape_op = Op::Get("relax.reshape");
static const Op& view_op = Op::Get("relax.memory.view");
static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op);
}

/*! \brief The base class for the storage allocation visitor. */
class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
// Reuse the input's token for memory ops that may alias their input.
SetTokens(call, GetTokens(call->args[0]));
return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
block_tokens.push_back(new_token.get());
}
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
Tokens tokens = GetTokens(call->args[0]);
ICHECK(!tokens.IsNested());
if (tokens.IsLeaf()) {
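The planning consequence, as a hedged sketch: because view-like results alias their inputs, StaticPlanBlockMemory assigns them the input's storage token rather than planning a fresh allocation.

import tvm
from tvm import relax

def plan_memory(mod: tvm.IRModule) -> tvm.IRModule:
    # With this change, results of R.reshape, R.memory.view, and
    # R.memory.ensure_zero_offset share their input's storage token,
    # so the planner allocates no separate storage for them.
    return relax.transform.StaticPlanBlockMemory()(mod)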
2 changes: 2 additions & 0 deletions src/runtime/cpu_device_api.cc
@@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;

bool SupportsDevicePointerArithmeticsOnHost() final { return true; }

static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
2 changes: 2 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
@@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
}

bool SupportsDevicePointerArithmeticsOnHost() final { return true; }

static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
19 changes: 19 additions & 0 deletions src/runtime/relax_vm/builtin.cc
@@ -545,6 +545,25 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
return ShapeTuple(out_shape);
});

TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) {
if (data->byte_offset == 0) {
return data;
}
auto* device_api = DeviceAPI::Get(data->device);
if (device_api->SupportsDevicePointerArithmeticsOnHost() &&
data->byte_offset % tvm::runtime::kAllocAlignment == 0) {
DLManagedTensor* dl_tensor = data.ToDLPack();
dl_tensor->dl_tensor.data =
reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
dl_tensor->dl_tensor.byte_offset = 0;
return NDArray::FromDLPack(dl_tensor);
} else {
auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
new_array.CopyFrom(data);
return new_array;
}
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
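The new builtin can also be exercised directly through the global function registry; a minimal sketch, assuming a default CPU build of TVM:

import numpy as np
import tvm

ensure_zero_offset = tvm.get_global_func("vm.builtin.ensure_zero_offset")

arr = tvm.nd.array(np.arange(16, dtype="float32"))
# arr already has byte_offset == 0, so it is returned unchanged; an offset
# view would either be rebased in place (when the device supports host-side
# pointer arithmetic and the offset is allocation-aligned) or copied into
# fresh zero-offset storage.
out = ensure_zero_offset(arr)
np.testing.assert_array_equal(out.numpy(), arr.numpy())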