From 05fcdd555eaac74717cd132ca434c90ae99381dd Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Fri, 22 Nov 2024 10:22:07 +0100 Subject: [PATCH] [MLIR][SPIRV-TO-LLVM] Support SPV_INTEL_split_barrier ops (#116648) Add conversion to LLVM for `SPV_INTEL_split_barrier` operations via conversion to SPIR-V built-ins. Signed-off-by: Victor Perez --- .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 38 +++++++++++++++---- .../SPIRVToLLVM/barrier-ops-to-llvm.mlir | 27 ++++++++++++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index ef0508e7ef5f0..b11511f21d03d 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1057,17 +1057,21 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, return call; } -class ControlBarrierPattern - : public SPIRVToLLVMConversion { +template +class ControlBarrierPattern : public SPIRVToLLVMConversion { public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + using OpAdaptor = typename SPIRVToLLVMConversion::OpAdaptor; + + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + static constexpr StringRef getFuncName(); LogicalResult - matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor, + matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii"; + constexpr StringRef funcName = getFuncName(); Operation *symbolTable = - controlBarrierOp->getParentWithTrait(); + controlBarrierOp->template getParentWithTrait(); Type i32 = rewriter.getI32Type(); @@ -1266,6 +1270,24 @@ class GroupReducePattern : public SPIRVToLLVMConversion { } }; +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z22__spirv_ControlBarrieriii"; +} + +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z33__spirv_ControlBarrierArriveINTELiii"; +} + +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z31__spirv_ControlBarrierWaitINTELiii"; +} + /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection /// should be reachable for conversion to succeed. The structure of the loop in /// LLVM dialect will be the following: @@ -1899,7 +1921,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns( ReturnPattern, ReturnValuePattern, // Barrier ops - ControlBarrierPattern, + ControlBarrierPattern, + ControlBarrierPattern, + ControlBarrierPattern, // Group reduction operations GroupReducePattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir index d53afeeea15d1..a5cae67a3d5c5 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s +// RUN: mlir-opt -convert-spirv-to-llvm -split-input-file %s | FileCheck %s //===----------------------------------------------------------------------===// // spirv.ControlBarrierOp @@ -21,3 +21,28 @@ spirv.func @control_barrier() "None" { spirv.ControlBarrier , , spirv.Return } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.INTEL.SplitBarrier +//===----------------------------------------------------------------------===// + +// CHECK-DAG: llvm.func spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return} +// CHECK-DAG: llvm.func spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return} + +// CHECK-LABEL: @split_barrier +spirv.func @split_barrier() "None" { + // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () + spirv.INTEL.ControlBarrierArrive , , + + // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () + spirv.INTEL.ControlBarrierWait , , + spirv.Return +}