diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 5041c3204e7565..8187e0f8c4c9d5 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -111,6 +111,7 @@ filegroup( "runtime_matmul_f64.cc", "runtime_matmul_s32.cc", "runtime_fork_join.cc", + "runtime_handle_ffi_call.cc", ], visibility = internal_visibility([":friends"]), ) @@ -138,6 +139,7 @@ filegroup( "runtime_fork_join.h", "runtime_lightweight_check.h", "runtime_matmul.h", + "runtime_handle_ffi_call.h", ], visibility = internal_visibility([":friends"]), ) @@ -543,6 +545,7 @@ cc_library( ":runtime_fft", ":runtime_fork_join", ":runtime_fp16", + ":runtime_handle_ffi_call", ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_acl", @@ -702,6 +705,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1235,6 +1239,33 @@ cc_library( ], ) +cc_library( + name = "runtime_handle_ffi_call", + srcs = ["runtime_handle_ffi_call.cc"], + hdrs = ["runtime_handle_ffi_call.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/service:custom_call_status_public_headers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 3e38a06748b808..a11a32f4688cbd 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -166,6 +166,8 @@ extern const char* const kOneDnnLayerNormSymbolName = "__xla_cpu_runtime_OneDnnLayerNorm"; extern const char* const kOneDnnMatMulReorderSymbolName = "__xla_cpu_runtime_OneDnnMatMulReorder"; +extern const char* const kHandleFfiCallSymbolName = + "__xla_cpu_runtime_HandleFfiCall"; namespace { diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index de84b7e7e54fed..f6cb91404ff6bf 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -90,6 +90,7 @@ extern const char* const kOneDnnMatMulSymbolName; extern const char* const kOneDnnSoftmaxSymbolName; extern const char* const kOneDnnLayerNormSymbolName; extern const char* const kOneDnnMatMulReorderSymbolName; +extern const char* const kHandleFfiCallSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 5a2ff9e61c03d1..a2d5a84d2c3468 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -33,6 +35,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -2814,30 +2817,35 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { } llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_); } - auto* output_address_arg = GetEmittedValueFor(custom_call); + auto* output_address = GetEmittedValueFor(custom_call); auto typed_custom_call = Cast(custom_call); switch (typed_custom_call->api_version()) { case CustomCallApiVersion::API_VERSION_ORIGINAL: EmitCallToFunc(custom_call->custom_call_target(), - {output_address_arg, operands_alloca}, b_.getVoidTy()); + {output_address, operands_alloca}, b_.getVoidTy()); break; case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: EmitCallToFunc(custom_call->custom_call_target(), - {output_address_arg, operands_alloca, GetStatusArgument()}, + {output_address, operands_alloca, GetStatusArgument()}, b_.getVoidTy()); EmitEarlyReturnIfErrorStatus(); break; case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: { absl::string_view opaque = typed_custom_call->opaque(); EmitCallToFunc(custom_call->custom_call_target(), - {output_address_arg, operands_alloca, + {output_address, operands_alloca, b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(opaque)), b_.getInt64(opaque.size()), GetStatusArgument()}, b_.getVoidTy()); EmitEarlyReturnIfErrorStatus(); break; } + case CustomCallApiVersion::API_VERSION_TYPED_FFI: { + EmitCallToFfi(typed_custom_call, output_address, operands_alloca); + EmitEarlyReturnIfErrorStatus(); + break; + } default: return Internal( "Unknown custom-call API version enum value: %d (%s)", @@ -3083,6 +3091,114 @@ llvm::Value* IrEmitter::EmitCallToFunc( return b_.CreateCall(func, arguments); } +template +static const Shape& GetShape(T&& arg) { + if constexpr (std::is_convertible_v, + Shape>) { + return arg; // convertible to shape, so just return + } else { + return arg->shape(); + } +}; + +template +llvm::AllocaInst* IrEmitter::StoreTypes(std::string_view alloca_name, + T&& args) { + auto* types_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt64(args.size()), alloca_name, &b_); + + for (int64_t i = 0; i < args.size(); ++i) { + llvm::Value* slot_in_types_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), types_alloca, i); + Store(b_.getInt32(GetShape(args[i]).element_type()), slot_in_types_alloca); + } + return types_alloca; +}; + +template +llvm::Value* IrEmitter::StoreShapes(std::string_view alloca_name, T&& args) { + // Prepare metadata for all buffers + // Shapes metadata is encoded using contiguous flattened dimension values: + // { + // 1: DIMCOUNT_1, DIM_1[1], DIM_1[2], ..., DIM_1[DIMCOUNT_1], + // \______________DIMCOUNT_1 _______________/ + // 2: DIMCOUNT_2, DIM_2[1], DIM_2[2], ..., DIM_2[DIMCOUNT_2], + // \______________DIMCOUNT_2 _______________/ + // .: ... + // N: DIMCOUNT_N, DIM_N[1], DIM_N[2], ..., DIM_N[DIMCOUNT_N], + // \______________DIMCOUNT_N _______________/ + // } + // where N is `operand_count`, and `DIMCOUNT_i` is the # of dimensions + std::size_t total_dims = + absl::c_accumulate(args, int64_t{0}, [](int64_t acc, auto&& arg) { + return acc + GetShape(arg).dimensions().size(); + }); + int64_t encoded_shapes_size = args.size() // the dimension count identifiers + + total_dims; // the # of dimension values + + llvm::Value* shapes_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt64Ty(), b_.getInt64(encoded_shapes_size), alloca_name, &b_); + + int64_t slot_id = 0; + for (int64_t i = 0; i < args.size(); ++i) { + auto dims = GetShape(args[i]).dimensions(); + llvm::Value* alloca_slot = + ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++); + // Store the operand count + Store(b_.getInt64(dims.size()), alloca_slot); + // Store the operand dimensions + for (int64_t dim : dims) { + alloca_slot = + ConstInBoundsGEP1_64(b_.getInt64Ty(), shapes_alloca, slot_id++); + Store(b_.getInt64(dim), alloca_slot); + } + } + CHECK_EQ(slot_id, encoded_shapes_size); // All slots are filled + return shapes_alloca; +}; + +llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, + llvm::Value* output_address, + llvm::AllocaInst* operands_alloca) { + const auto& operands = absl::MakeSpan(custom_call->operands()); + const auto& shape = custom_call->shape(); + const auto& result_shapes = + shape.IsTuple() ? shape.tuple_shapes() : std::vector({shape}); + + auto operand_types_alloca = StoreTypes("meta_types_operands", operands); + auto operand_shapes_alloca = StoreShapes("meta_shapes_operands", operands); + + auto result_types_alloca = StoreTypes("meta_types_results", result_shapes); + auto result_shapes_alloca = StoreShapes("meta_shapes_results", result_shapes); + + const absl::string_view target = custom_call->custom_call_target(); // name + const absl::string_view opaque = custom_call->opaque(); + + const auto target_ref = llvm_ir::AsStringRef(target); + const auto opaque_ref = llvm_ir::AsStringRef(opaque); + + std::vector arguments = { + b_.CreateGlobalStringPtr(target_ref), // target_name_ptr + b_.getInt64(target.size()), // target_name_len + output_address, // output + operands_alloca, // inputs + b_.CreateGlobalStringPtr(opaque_ref), // opaque_str_ptr + b_.getInt64(opaque.size()), // opaque_str_len + GetStatusArgument(), // status_opaque + operand_types_alloca, // operand_types + b_.getInt64(operands.size()), // operand_count + operand_shapes_alloca, // operand_dims + result_types_alloca, // result_types + b_.getInt64(result_shapes.size()), // result_count + result_shapes_alloca, // result_dims + }; + + return EmitCallToFunc("__xla_cpu_runtime_HandleFfiCall", arguments, + b_.getVoidTy(), + /* does_not_throw = */ false, + /* only_accesses_arg_memory = */ true); +} + void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, int64_t element_count, PrimitiveType primitive_type, diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 93516138be50e4..bc99acd4f2b2f9 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -450,6 +450,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, bool only_accesses_arg_memory = false, bool only_accesses_inaccessible_mem_or_arg_mem = false); + template + llvm::AllocaInst* StoreTypes(std::string_view alloca_name, T&& args); + template + llvm::Value* StoreShapes(std::string_view alloca_name, T&& args); + + // Emits a call to a proxy that builds an FFI call frame for `custom_call` + llvm::Value* EmitCallToFfi(HloCustomCallInstruction* custom_call, + llvm::Value* output_address, + llvm::AllocaInst* operands_alloca); + // Assignment of the buffers needed by the computation and their shape // information. const BufferAssignment& assignment_; diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc new file mode 100644 index 00000000000000..a962c998d3d827 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -0,0 +1,244 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime_handle_ffi_call.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/primitive_util.h" +#include "xla/service/custom_call_status.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace ffi = xla::ffi; + +namespace { + +using Attribute = ffi::CallFrameBuilder::FlatAttribute; +using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; + +// TODO(heinsaar): This BuildAttributesMap() is originally an identical +// copy-paste of the same function in custom_call_thunk.cc +// May make sense to have one in a common place & reuse. +absl::StatusOr BuildAttributesMap(mlir::DictionaryAttr dict) { + AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + + auto integer = [&](mlir::IntegerAttr integer) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + }; + + auto fp = [&](mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(fp.getValue().convertToFloat()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } + }; + + auto str = [&](mlir::StringAttr str) { + attributes[name] = str.getValue().str(); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(kv.getValue()) + .Case(integer) + .Case(fp) + .Case(str) + .Default([&](mlir::Attribute) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute type for attribute: ", name)); + })); + } + return attributes; +} + +absl::Span DecodeDims(int64_t* encoded_dims_data) { + auto dims_count = encoded_dims_data[0]; + auto dims_begin = encoded_dims_data + 1; + return absl::MakeSpan(dims_begin, dims_begin + dims_count); +} + +// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with +// an explicit template parameter list. +class ArgInserter { + public: + template + explicit ArgInserter(B&& b) : b_(std::forward(b)) {} + + template + void operator()(Args&&... args) const { + b_.AddBufferArg(std::forward(args)...); + } + + private: + ffi::CallFrameBuilder& b_; +}; + +// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with +// an explicit template parameter list. +class RetInserter { + public: + template + explicit RetInserter(B&& b) : b_(std::forward(b)) {} + + template + void operator()(Args&&... args) const { + b_.AddBufferRet(std::forward(args)...); + } + + private: + ffi::CallFrameBuilder& b_; +}; + +template +void BuildBuffers(absl::Span types, int64_t* encoded_dims, + absl::Span address_space, Builder&& builder) { + int64_t dim_pos = 0; + for (int64_t i = 0; i < types.size(); ++i) { + auto dtype = static_cast(types[i]); + auto dims = DecodeDims(encoded_dims + dim_pos); + auto elem_count = absl::c_accumulate(dims, 1, std::multiplies()); + auto data_width = xla::primitive_util::ByteWidth(dtype) * elem_count; + + builder(tensorflow::se::DeviceMemoryBase(address_space[i], data_width), + /*type = */ dtype, + /*dims = */ dims); + dim_pos += 1; // Jumps over count value + dim_pos += dims.size(); // Jumps over all dimensions in a shape + } +} + +inline absl::Status BuildAndCallFfi( + std::string_view target_name, std::string_view backend_config, + absl::Span outputs, absl::Span inputs, + absl::Span result_types, int64_t* result_dims, + absl::Span operand_types, int64_t* operand_dims) { + CHECK_EQ(outputs.size(), result_types.size()); + CHECK_EQ(inputs.size(), operand_types.size()); + + if (absl::c_any_of(operand_types, [](int32_t type) { + return static_cast(type) == + xla::PrimitiveType::TUPLE; + })) { + return absl::InternalError( + "Tuple operands are not supported yet in typed FFI custom calls."); + } + + // Find the registered FFI handler for this custom call target. + absl::StatusOr registration = + ffi::FindHandler(target_name, "Host"); + + if (!registration.ok()) { + return absl::UnimplementedError( + absl::StrCat("No registered implementation for custom call to ", + target_name, " for Host.")); + } + + // For FFI handlers backend config must be a compatible MLIR dictionary. + mlir::MLIRContext mlir_context; + ffi::CallFrameBuilder::FlatAttributesMap attributes; + if (!backend_config.empty()) { + // Backend config not empty, so proceed to parse it into an MLIR attribute + // and build an MLIR compatible map of attributes out of it. + mlir::Attribute attr = mlir::parseAttribute(backend_config, &mlir_context); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + } else { + return absl::InternalError( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); + } + } + + ffi::CallFrameBuilder builder; + + // Forward the constructed attributes to the call frame + ffi::CallFrameBuilder::AttributesBuilder attrs; + attrs.Append(std::move(attributes)); + builder.AddAttributes(attrs.Build()); + + // Decode dimensions metadata into shapes and build operand & result buffers + BuildBuffers(operand_types, operand_dims, inputs, ArgInserter(builder)); + BuildBuffers(result_types, result_dims, outputs, RetInserter(builder)); + + ffi::CallFrame call_frame = builder.Build(); + return ffi::Call(registration->handler, call_frame); // Status +} + +} // namespace + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_HandleFfiCall( + const char* target_name_ptr, int64_t target_name_len, void* output, + void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, + void* status_opaque, int32_t* operand_types, int64_t operand_count, + int64_t* operand_dims, int32_t* result_types, int64_t result_count, + int64_t* result_dims) { + auto target_name = absl::string_view(target_name_ptr, target_name_len); + auto backend_config = absl::string_view(opaque_str_ptr, opaque_str_len); + auto xla_status = reinterpret_cast(status_opaque); + + void** outputs = &output; + if (result_count > 1) { // output is a tuple + outputs = reinterpret_cast(output); + } + + absl::Status status = BuildAndCallFfi( + target_name, backend_config, absl::MakeSpan(outputs, result_count), + absl::MakeSpan(inputs, operand_count), + absl::MakeSpan(result_types, result_count), result_dims, + absl::MakeSpan(operand_types, operand_count), operand_dims); + + if (!status.ok()) { + // In the future, status propagation will likely be possible. + // However, currently this has to pass through XlaCustomCallStatus + // which lacks functionality for status codes (it is fixed on INTERNAL) + XlaCustomCallStatusSetFailure(xla_status, status.message().data(), + status.message().size()); + } +} diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h new file mode 100644 index 00000000000000..e8afb236b2bb26 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_ +#define XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_ + +#include + +extern "C" { + +extern void __xla_cpu_runtime_HandleFfiCall( + const char* target_name_ptr, int64_t target_name_len, void* output, + void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, + void* status_opaque, int32_t* operand_types, int64_t operand_count, + int64_t* operand_dims, int32_t* result_types, int64_t result_count, + int64_t* result_dims); + +} // extern "C" + +#endif // XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_ diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 04f522c770cc83..b3373897ff4626 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/service/cpu/runtime_fft.h" #include "xla/service/cpu/runtime_fork_join.h" #include "xla/service/cpu/runtime_fp16.h" +#include "xla/service/cpu/runtime_handle_ffi_call.h" #include "xla/service/cpu/runtime_key_value_sort.h" #include "xla/service/cpu/runtime_matmul.h" #include "xla/service/cpu/runtime_matmul_acl.h" @@ -541,6 +542,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(TopKF32); REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); + REGISTER_CPU_RUNTIME_SYMBOL(HandleFfiCall); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) REGISTER_CPU_RUNTIME_SYMBOL(OneDnnMatMul); REGISTER_CPU_RUNTIME_SYMBOL(OneDnnSoftmax); diff --git a/third_party/xla/xla/service/llvm_ir/ir_builder_mixin.h b/third_party/xla/xla/service/llvm_ir/ir_builder_mixin.h index 23a00c242f9ad4..5e3e5fd4007503 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_builder_mixin.h +++ b/third_party/xla/xla/service/llvm_ir/ir_builder_mixin.h @@ -107,6 +107,12 @@ class IrBuilderMixin { std::forward(args)...); } + template + llvm::Value* ConstInBoundsGEP1_64(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_64( + std::forward(args)...); + } + template llvm::Value* FAdd(Args&&... args) { return mixin_builder()->CreateFAdd(std::forward(args)...); diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index edf4ace76edbd1..8349a974a4ac1d 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -52,7 +52,13 @@ class DeviceMemoryBase { // region. An opaque pointer may be provided -- see header for details on the // opacity of that pointer. explicit DeviceMemoryBase(void *opaque = nullptr, uint64_t size = 0) - : opaque_(opaque), size_(size) {} + : opaque_(opaque), size_(size) { + // TODO(b/336267585): This constructor dangerously encourages + // DeviceMemoryBase(mem) which would imply + // DeviceMemoryBase(mem, 0) + // We should delete & resolve any dependencies. + // explicit DeviceMemoryBase(void *opaque) = delete; + } // Returns whether the backing memory is the null pointer. // A `== nullptr` convenience method is also provided. diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index cf39c2e416df65..626b5da14fa141 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1782,6 +1782,7 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index a31ad049aae45f..5fe3b90f8de51e 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" @@ -648,14 +649,7 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$FfiTupleRotate", } // namespace -// TODO(abanas): When #10056 (typed FFI support) is ready, this class can be -// replaced by a simple 'using FfiCustomCallTest = CustomCallTest;' -class FfiCustomCallTest : public CustomCallTest { - protected: - void SetUp() override { - GTEST_SKIP() << "Typed FFI is not supported yet on CPU"; - } -}; +using FfiCustomCallTest = CustomCallTest; XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { auto module = CreateNewVerifiedModule(); @@ -911,7 +905,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleAttrPointer) { auto n = 4.0f; auto ptr = reinterpret_cast(&n); builder.AddInstruction(HloInstruction::CreateCustomCall( - r0f32_, {constant}, "__xla_test$$FfiR0F32AddN", + r0f32_, {constant}, "__xla_test$$FfiR0F32AddNPointer", /*opaque=*/absl::StrFormat("{n = %d : i64}", ptr), /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); @@ -1089,6 +1083,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleOutput) { } XLA_TEST_F(FfiCustomCallTest, FfiTupleInput) { + GTEST_SKIP() << "Tuple inputs not yet implemented."; const char* const kModuleStr = R"( HloModule m