Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Implement R.ensure_zero_offset and update memory planning for R.view #17145

Merged
merged 12 commits into from
Aug 6, 2024
Next Next commit
[Relax] Implement R.ensure_aligned and update memory planning for R.view
  • Loading branch information
vinx13 committed Aug 2, 2024
commit 2f5d03603e149784d7fff476e6e0704c85f17fc1
2 changes: 1 addition & 1 deletion python/tvm/relax/op/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view
from .view import view, ensure_aligned
17 changes: 17 additions & 0 deletions python/tvm/relax/op/memory/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore


def ensure_aligned(data: Expr) -> Expr:
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
"""
Ensure the tensor has elem_offset == 0. A copy will be made if necessary.

Parameters
----------
data : relax.Expr
The input tensor

Results
-------
result : relax.Expr
The aligned tensor
"""
return _ffi_api.ensure_aligned(data) # type: ignore
20 changes: 20 additions & 0 deletions src/relax/backend/vm/vm_builtin_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == view_op_) {
return View(call);
} else if (call->op == ensure_aligned_op_) {
return EnsureAligned(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
Expand Down Expand Up @@ -124,6 +128,19 @@ class VMBuiltinLowerMutator : public ExprMutator {
}
}

Expr View(const Call& view_node) {
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
StructInfoDeriveFunc infer_sinfo_env_func;
infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo});
}

Expr EnsureAligned(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr ShapeOf(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
ICHECK(call_node->struct_info_.defined());
Expand Down Expand Up @@ -188,6 +205,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& view_op_ = Op::Get("relax.memory.view");
const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
Expand All @@ -208,6 +227,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"};
};

Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
Expand Down
34 changes: 28 additions & 6 deletions src/relax/op/memory/view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,12 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
relative_byte_offset = relax::PrimValue::Int64(0);
}

StructInfoDeriveFunc infer_sinfo_env_func;
infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);

ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) &&
relative_byte_offset.same_as(call->args[3])) {
return call;
}

return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset});
return Call(call->op, {data, shape, dtype, relative_byte_offset});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change means that R.memory.view is still present in the output of LegalizeOps, but a legalization function should replace the operator with a lowered form.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only does the inference of void type args and leave the lowering to the later pass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to do the inference of the shape/dtype prior to lowering, because it could result in unexpected StructInfo inference later on.

Suppose we have R.memory.view(arg, shape=[16]). This returns a view into the first 16 elements of arg, without changing the dtype. If an IRModule pass updates the datatype of arg, then that new datatype should also propagate to the view. However, legalizing it to R.memory.view(arg, shape=[16], dtype="float16") would return a view into arg, interpreting the first 32 bytes as if they were "float16". Now, if an IRModule pass updates the datatype of arg, the view would still be "float16". To avoid this issue, the unknown arguments shouldn't be filled in until the lowering is about to occur.

What if we were to remove .set_attr<FLegalize> altogether, and only have .set_attr<FLowerBuiltin>? That way, we preserve the R.memory.view as-is until it is ready to be lowered. The LegalizeOps pass would then be a no-op for R.memory.view, and only the LowerRuntimeBuiltin pass would change it at all.

}

TVM_REGISTER_OP("relax.memory.view")
Expand All @@ -355,5 +354,28 @@ TVM_REGISTER_OP("relax.memory.view")
.set_attr<FLegalize>("FLegalize", LegalizeView)
.set_attr<Bool>("FPurity", Bool(true));

Expr ensure_aligned(const Expr& x) {
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
static const Op& op = Op::Get("relax.memory.ensure_aligned");
return Call(op, {x});
}

TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned);

StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << call->op << " should receive 1 argument, "
<< "but received " << call->args);
}
return GetStructInfo(call->args[0]);
}

TVM_REGISTER_OP("relax.memory.ensure_aligned")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureAligned)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relax/op/memory/view.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace relax {
/*! \brief View a tensor with different properties. */
Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);

/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
Expr ensure_aligned(const Expr& x);

} // namespace relax
} // namespace tvm

Expand Down
13 changes: 9 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};

/*! \brief Check if the input op is "relax.reshape". */
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
/*! \brief Check if the input op is a memory op that return the same buffer as the input buffer. */
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
bool IsInplaceMemoryOp(const Expr& op) {
static const Op& reshape_op = Op::Get("relax.reshape");
static const Op& view_op = Op::Get("relax.memory.view");
static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
}

/*! \brief The base class for the storage allocation visitor. */
class StorageAllocatorBaseVisitor : public ExprVisitor {
Expand Down Expand Up @@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
// Reuse the input's token for builtin reshape.
SetTokens(call, GetTokens(call->args[0]));
return;
Expand Down Expand Up @@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
block_tokens.push_back(new_token.get());
}
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
Tokens tokens = GetTokens(call->args[0]);
ICHECK(!tokens.IsNested());
if (tokens.IsLeaf()) {
Expand Down
13 changes: 13 additions & 0 deletions src/runtime/relax_vm/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
return ShapeTuple(out_shape);
});

TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
if (data->byte_offset == 0) {
return data;
}
DLManagedTensor* dl_tensor = data.ToDLPack();
dl_tensor->dl_tensor.data =
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
dl_tensor->dl_tensor.byte_offset = 0;
// For platforms that does not support pointer arithmetic, we need to copy the data to a new
// buffer.
return NDArray::FromDLPack(dl_tensor);
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
Expand Down
105 changes: 39 additions & 66 deletions tests/python/relax/test_op_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([64, 64]),
R.dtype("float32"),
R.prim_value(0),
)
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
return B

After = tvm.relax.transform.LegalizeOps()(Before)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After replacing the .set_attr<FLegalize> for R.memory.view with .set_attr<FLowerBuiltin>, the changes to these unit tests can be reverted. Instead, any use of LegalizeOps in the unit tests would instead call LowerRuntimeBuiltin.

Expand All @@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")):
class Expected:
@R.function
def main(A: R.Tensor(dtype="float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([64, 64]),
R.dtype("float32"),
R.prim_value(0),
)
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
return B

After = tvm.relax.transform.LegalizeOps()(Before)
Expand All @@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([4096]),
R.dtype("int32"),
R.prim_value(0),
B = R.memory.view(
A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0
)
return B

Expand All @@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([4096]),
R.dtype("float32"),
R.prim_value(0),
B = R.memory.view(
A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32"
)
return B

Expand Down Expand Up @@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "uint8")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
B = R.memory.view(
A,
R.shape([512]),
R.dtype("int32"),
R.prim_value(0),
shape=R.shape([512]),
dtype=R.dtype("int32"),
relative_byte_offset=R.prim_value(0),
)
C = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
C = R.memory.view(
A,
R.shape([16, 64]),
R.dtype("float16"),
R.prim_value(2048),
shape=R.shape([16, 64]),
dtype=R.dtype("float16"),
relative_byte_offset=R.prim_value(2048),
)
return (B, C)

Expand Down Expand Up @@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")):
tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1])


@tvm.testing.parametrize_targets("llvm", "cuda")
def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev):
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.memory.view(
A,
shape=R.shape([16, 64]),
relative_byte_offset=32 * 64 * 4,
)
C = R.memory.ensure_aligned(B)
return C

built = tvm.relax.build(Module, target=target)
vm = tvm.relax.VirtualMachine(built, device=dev)

np_input = np.random.random([4096]).astype("float32")
tvm_input = tvm.nd.array(np_input, dev)
tvm_output = vm["main"](tvm_input)
np_expected = np_input.reshape(64, 64)[32:48, :]

tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)


if __name__ == "__main__":
tvm.testing.main()
55 changes: 55 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,5 +1449,60 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


def test_view():
@I.ir_module
class Before:
@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@R.function
def main():
cls = Before
x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
x1 = R.memory.view(x, [128], "float32", 0)
x2 = R.memory.ensure_aligned(x1)
y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
cls.tir_exp(x2, y)
z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
cls.tir_exp(y, z)
return z

@I.ir_module
class Expected:
@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@R.function
def main() -> R.Tensor((128,), dtype="float32"):
cls = Expected
storage: R.Object = R.memory.alloc_storage(
R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
)
x1: R.Tensor((128,), dtype="float32") = R.memory.view(
x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
)
x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1)
storage1: R.Object = R.memory.alloc_storage(
R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
)
cls.tir_exp(x2, y)
z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global")
)
cls.tir_exp(y, z)
return z

after = relax.transform.StaticPlanBlockMemory()(Before)
tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
tvm.testing.main()