Skip to content

Commit

Permalink
[MLIR][SPIRV-TO-LLVM] Support SPV_INTEL_split_barrier ops (#116648)
Browse files Browse the repository at this point in the history
Add conversion to LLVM for `SPV_INTEL_split_barrier` operations via
conversion to SPIR-V built-ins.

Signed-off-by: Victor Perez <victor.perez@codeplay.com>
  • Loading branch information
victor-eds authored Nov 22, 2024
1 parent 632c5d2 commit 05fcdd5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 8 deletions.
38 changes: 31 additions & 7 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,17 +1057,21 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
return call;
}

class ControlBarrierPattern
: public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
template <typename BarrierOpTy>
class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
public:
using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;

using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;

static constexpr StringRef getFuncName();

LogicalResult
matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
constexpr StringRef funcName = getFuncName();
Operation *symbolTable =
controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();

Type i32 = rewriter.getI32Type();

Expand Down Expand Up @@ -1266,6 +1270,24 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
}
};

template <>
constexpr StringRef
ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
return "_Z22__spirv_ControlBarrieriii";
}

template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
return "_Z33__spirv_ControlBarrierArriveINTELiii";
}

template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
return "_Z31__spirv_ControlBarrierWaitINTELiii";
}

/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
Expand Down Expand Up @@ -1899,7 +1921,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ReturnPattern, ReturnValuePattern,

// Barrier ops
ControlBarrierPattern,
ControlBarrierPattern<spirv::ControlBarrierOp>,
ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,

// Group reduction operations
GroupReducePattern<spirv::GroupIAddOp>,
Expand Down
27 changes: 26 additions & 1 deletion mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
// RUN: mlir-opt -convert-spirv-to-llvm -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
Expand All @@ -21,3 +21,28 @@ spirv.func @control_barrier() "None" {
spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//

// CHECK-DAG: llvm.func spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return}

// CHECK-LABEL: @split_barrier
spirv.func @split_barrier() "None" {
// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>

// CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <WorkgroupMemory>
spirv.Return
}

0 comments on commit 05fcdd5

Please sign in to comment.