Skip to content

Commit

Permalink
[xla:cpu] Make HandleCustomCall support typed FFI
Browse files Browse the repository at this point in the history
Fixes #10056
Co-authored-by: pparuzel <paruzelp@google.com>
Co-authored-by: Adam-Banas <adambanas@google.com>
PiperOrigin-RevId: 615387700
  • Loading branch information
heinsaar authored and tensorflower-gardener committed Apr 26, 2024
1 parent 6168423 commit cac501c
Show file tree
Hide file tree
Showing 12 changed files with 462 additions and 14 deletions.
32 changes: 32 additions & 0 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ filegroup(
"runtime_matmul_f64.cc",
"runtime_matmul_s32.cc",
"runtime_fork_join.cc",
"runtime_handle_ffi_call.cc",
],
visibility = internal_visibility([":friends"]),
)
Expand Down Expand Up @@ -138,6 +139,7 @@ filegroup(
"runtime_fork_join.h",
"runtime_lightweight_check.h",
"runtime_matmul.h",
"runtime_handle_ffi_call.h",
],
visibility = internal_visibility([":friends"]),
)
Expand Down Expand Up @@ -543,6 +545,7 @@ cc_library(
":runtime_fft",
":runtime_fork_join",
":runtime_fp16",
":runtime_handle_ffi_call",
":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_acl",
Expand Down Expand Up @@ -702,6 +705,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -1235,6 +1239,34 @@ cc_library(
],
)

cc_library(
name = "runtime_handle_ffi_call",
srcs = ["runtime_handle_ffi_call.cc"],
hdrs = ["runtime_handle_ffi_call.h"],
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
#"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:IR",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/ffi:call_frame",
"//xla/ffi:ffi_api",
"//xla/service:custom_call_status_public_headers",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "cpu_runtime_test",
srcs = ["cpu_runtime_test.cc"],
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ extern const char* const kOneDnnLayerNormSymbolName =
"__xla_cpu_runtime_OneDnnLayerNorm";
extern const char* const kOneDnnMatMulReorderSymbolName =
"__xla_cpu_runtime_OneDnnMatMulReorder";
extern const char* const kHandleFfiCallSymbolName =
"__xla_cpu_runtime_HandleFfiCall";

namespace {

Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/cpu/cpu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ extern const char* const kOneDnnMatMulSymbolName;
extern const char* const kOneDnnSoftmaxSymbolName;
extern const char* const kOneDnnLayerNormSymbolName;
extern const char* const kOneDnnMatMulReorderSymbolName;
extern const char* const kHandleFfiCallSymbolName;

// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
Expand Down
124 changes: 120 additions & 4 deletions third_party/xla/xla/service/cpu/ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ limitations under the License.
#include <stdint.h>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

Expand All @@ -33,6 +35,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -2814,30 +2817,35 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
}
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
}
auto* output_address_arg = GetEmittedValueFor(custom_call);
auto* output_address = GetEmittedValueFor(custom_call);

auto typed_custom_call = Cast<HloCustomCallInstruction>(custom_call);
switch (typed_custom_call->api_version()) {
case CustomCallApiVersion::API_VERSION_ORIGINAL:
EmitCallToFunc(custom_call->custom_call_target(),
{output_address_arg, operands_alloca}, b_.getVoidTy());
{output_address, operands_alloca}, b_.getVoidTy());
break;
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
EmitCallToFunc(custom_call->custom_call_target(),
{output_address_arg, operands_alloca, GetStatusArgument()},
{output_address, operands_alloca, GetStatusArgument()},
b_.getVoidTy());
EmitEarlyReturnIfErrorStatus();
break;
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: {
absl::string_view opaque = typed_custom_call->opaque();
EmitCallToFunc(custom_call->custom_call_target(),
{output_address_arg, operands_alloca,
{output_address, operands_alloca,
b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(opaque)),
b_.getInt64(opaque.size()), GetStatusArgument()},
b_.getVoidTy());
EmitEarlyReturnIfErrorStatus();
break;
}
case CustomCallApiVersion::API_VERSION_TYPED_FFI: {
EmitCallToFfi(typed_custom_call, output_address, operands_alloca);
EmitEarlyReturnIfErrorStatus();
break;
}
default:
return Internal(
"Unknown custom-call API version enum value: %d (%s)",
Expand Down Expand Up @@ -3083,6 +3091,114 @@ llvm::Value* IrEmitter::EmitCallToFunc(
return b_.CreateCall(func, arguments);
}

template <typename T>
static const Shape& GetShape(T&& arg) {
if constexpr (std::is_convertible_v<absl::remove_cvref_t<decltype(arg)>,
Shape>) {
return arg; // convertible to shape, so just return
} else {
return arg->shape();
}
};

template <typename T>
llvm::AllocaInst* IrEmitter::StoreTypes(std::string_view alloca_name,
T&& args) {
auto* types_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt32Ty(), b_.getInt64(args.size()), alloca_name, &b_);

for (int64_t i = 0; i < args.size(); ++i) {
llvm::Value* slot_in_types_alloca =
ConstInBoundsGEP1_32(b_.getInt32Ty(), types_alloca, i);
Store(b_.getInt32(GetShape(args[i]).element_type()), slot_in_types_alloca);
}
return types_alloca;
};

template <typename T>
llvm::Value* IrEmitter::StoreShapes(std::string_view alloca_name, T&& args) {
// Prepare metadata for all buffers
// Shapes metadata is encoded using contiguous flattened dimension values:
// {
// 1: DIMCOUNT_1, DIM_1[1], DIM_1[2], ..., DIM_1[DIMCOUNT_1],
// \______________DIMCOUNT_1 _______________/
// 2: DIMCOUNT_2, DIM_2[1], DIM_2[2], ..., DIM_2[DIMCOUNT_2],
// \______________DIMCOUNT_2 _______________/
// .: ...
// N: DIMCOUNT_N, DIM_N[1], DIM_N[2], ..., DIM_N[DIMCOUNT_N],
// \______________DIMCOUNT_N _______________/
// }
// where N is `operand_count`, and `DIMCOUNT_i` is the # of dimensions
std::size_t total_dims =
absl::c_accumulate(args, int64_t{0}, [](int64_t acc, auto&& arg) {
return acc + GetShape(arg).dimensions().size();
});
int64_t encoded_shapes_size = args.size() // the dimension count identifiers
+ total_dims; // the # of dimension values

llvm::Value* shapes_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b_.getInt64Ty(), b_.getInt64(encoded_shapes_size), alloca_name, &b_);

int64_t slot_id = 0;
for (int64_t i = 0; i < args.size(); ++i) {
auto dims = GetShape(args[i]).dimensions();
llvm::Value* alloca_slot =
ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++);
// Store the operand count
Store(b_.getInt64(dims.size()), alloca_slot);
// Store the operand dimensions
for (int64_t dim : dims) {
alloca_slot =
ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++);
Store(b_.getInt64(dim), alloca_slot);
}
}
CHECK_EQ(slot_id, encoded_shapes_size); // All slots are filled
return shapes_alloca;
};

llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call,
llvm::Value* output_address,
llvm::AllocaInst* operands_alloca) {
const auto& operands = absl::MakeSpan(custom_call->operands());
const auto& shape = custom_call->shape();
const auto& result_shapes =
shape.IsTuple() ? shape.tuple_shapes() : std::vector<Shape>({shape});

auto operand_types_alloca = StoreTypes("meta_types_operands", operands);
auto operand_shapes_alloca = StoreShapes("meta_shapes_operands", operands);

auto result_types_alloca = StoreTypes("meta_types_results", result_shapes);
auto result_shapes_alloca = StoreShapes("meta_shapes_results", result_shapes);

const absl::string_view target = custom_call->custom_call_target(); // name
const absl::string_view opaque = custom_call->opaque();

const auto target_ref = llvm_ir::AsStringRef(target);
const auto opaque_ref = llvm_ir::AsStringRef(opaque);

std::vector<llvm::Value*> arguments = {
b_.CreateGlobalStringPtr(target_ref), // target_name_ptr
b_.getInt64(target.size()), // target_name_len
output_address, // output
operands_alloca, // inputs
b_.CreateGlobalStringPtr(opaque_ref), // opaque_str_ptr
b_.getInt64(opaque.size()), // opaque_str_len
GetStatusArgument(), // status_opaque
operand_types_alloca, // operand_types
b_.getInt64(operands.size()), // operand_count
operand_shapes_alloca, // operand_dims
result_types_alloca, // result_types
b_.getInt64(result_shapes.size()), // result_count
result_shapes_alloca, // result_dims
};

return EmitCallToFunc("__xla_cpu_runtime_HandleFfiCall", arguments,
b_.getVoidTy(),
/* does_not_throw = */ false,
/* only_accesses_arg_memory = */ true);
}

void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
int64_t element_count,
PrimitiveType primitive_type,
Expand Down
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/cpu/ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,16 @@ class IrEmitter : public DfsHloVisitorWithDefault,
bool only_accesses_arg_memory = false,
bool only_accesses_inaccessible_mem_or_arg_mem = false);

template <typename T>
llvm::AllocaInst* StoreTypes(std::string_view alloca_name, T&& args);
template <typename T>
llvm::Value* StoreShapes(std::string_view alloca_name, T&& args);

// Emits a call to a proxy that builds an FFI call frame for `custom_call`
llvm::Value* EmitCallToFfi(HloCustomCallInstruction* custom_call,
llvm::Value* output_address,
llvm::AllocaInst* operands_alloca);

// Assignment of the buffers needed by the computation and their shape
// information.
const BufferAssignment& assignment_;
Expand Down
Loading

0 comments on commit cac501c

Please sign in to comment.