[mlir][scf] Fix bug in software pipeliner and simplify the logic
Fix a bug when pipelining while interleaving stages. Re-do the logic to
only consider cloned operands when updating the use-def chain.

Differential Revision: https://reviews.llvm.org/D145598
ThomasRaoux committed Mar 8, 2023
1 parent dea96e7 commit 117db47
Showing 2 changed files with 97 additions and 85 deletions.
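The heart of the change, sketched outside the diff for readability: instead of scanning all uses of an old value and patching those that happen to fall inside the freshly cloned op, the kernel builder now walks the operands of the original op (and of ops nested in its regions) and rewrites the matching operand on the clone through the IRMapping, so only cloned operands are ever updated. The helper below is a minimal sketch distilled from the patch, not code from the repository; remapClonedOperands and the remapOperand callback are illustrative names standing in for the induction-variable, iter-arg, and cross-stage cases handled inline in createKernel.

#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Rewire a cloned op by walking the *original* op's operands and resolving
// each owner through `mapping`; only cloned operands are touched.
// `remapOperand` returns the replacement value, or a null Value to keep the
// operand as the clone already has it.
static void remapClonedOperands(Operation *op, IRMapping &mapping,
                                function_ref<Value(OpOperand &)> remapOperand) {
  SmallVector<OpOperand *> operands;
  // Collect the operands of the op and of every op nested in its regions.
  op->walk([&operands](Operation *nestedOp) {
    for (OpOperand &operand : nestedOp->getOpOperands())
      operands.push_back(&operand);
  });
  for (OpOperand *operand : operands) {
    // The clone of the operand's owner, recorded when the op was cloned.
    Operation *nestedNewOp = mapping.lookup(operand->getOwner());
    if (Value replacement = remapOperand(*operand))
      nestedNewOp->setOperand(operand->getOperandNumber(), replacement);
  }
}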
138 changes: 53 additions & 85 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -294,81 +294,6 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
return newForOp;
}

/// Replace any use of `target` with `replacement` in `op`'s operands or within
/// `op`'s nested regions.
static void replaceInOp(Operation *op, Value target, Value replacement) {
for (auto &use : llvm::make_early_inc_range(target.getUses())) {
if (op->isAncestor(use.getOwner()))
use.set(replacement);
}
}

/// Given a cloned op in the new kernel body, updates induction variable uses.
/// We replace it with a version incremented based on the stage where it is
/// used.
static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
Operation *newOp, Value newForIv,
unsigned maxStage, unsigned useStage,
unsigned step) {
rewriter.setInsertionPoint(newOp);
Value offset = rewriter.create<arith::ConstantIndexOp>(
loc, (maxStage - useStage) * step);
Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
replaceInOp(newOp, newForIv, iv);
rewriter.setInsertionPointAfter(newOp);
}

/// If the value is a loop carried value coming from stage N + 1, remap it;
/// it will become a direct use.
static void updateIterArgUses(RewriterBase &rewriter, IRMapping &bvm,
Operation *newOp, ForOp oldForOp, ForOp newForOp,
unsigned useStage,
const DenseMap<Operation *, unsigned> &stages) {

for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
Operation *dep = yieldedVal.getDefiningOp();
if (!dep)
continue;
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
if (stageDep->second != useStage + 1)
continue;
Value replacement = bvm.lookup(yieldedVal);
replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
}
}

/// For an operand defined in a previous stage we need to remap it to use the
/// correct region argument. We look for the right version of the Value based
/// on the stage where it is used.
static void updateCrossStageUses(
RewriterBase &rewriter, Operation *newOp, IRMapping &bvm, ForOp newForOp,
unsigned useStage, const DenseMap<Operation *, unsigned> &stages,
const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
// Because we automatically cloned the sub-regions, there's no simple way
// to walk the nested regions in pairs of (oldOps, newOps), so we just
// traverse the set of remapped loop arguments, filter which ones are
// relevant, and replace any uses.
for (auto [remapPair, newIterIdx] : loopArgMap) {
auto [crossArgValue, stageIdx] = remapPair;
Operation *def = crossArgValue.getDefiningOp();
assert(def);
unsigned stageDef = stages.lookup(def);
if (useStage <= stageDef || useStage - stageDef != stageIdx)
continue;

// Use "lookupOrDefault" for the target value because some operations
// are remapped, while in other cases the original will be present.
Value target = bvm.lookupOrDefault(crossArgValue);
Value replacement = newForOp.getRegionIterArg(newIterIdx);

// Replace uses in the new op's operands and any nested uses.
replaceInOp(newOp, target, replacement);
}
}

void LoopPipelinerInternal::createKernel(
scf::ForOp newForOp,
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -400,16 +325,59 @@ void LoopPipelinerInternal::createKernel(
for (Operation *op : opOrder) {
int64_t useStage = stages[op];
auto *newOp = rewriter.clone(*op, mapping);

// Within the kernel body, update uses of the induction variable, uses of
// the original iter args, and uses of cross stage values.
updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
newForOp.getInductionVar(), maxStage,
stages[op], step);
updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
stages);
updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
loopArgMap);
SmallVector<OpOperand *> operands;
// Collect all the operands for the cloned op and its nested ops.
op->walk([&operands](Operation *nestedOp) {
for (OpOperand &operand : nestedOp->getOpOperands()) {
operands.push_back(&operand);
}
});
for (OpOperand *operand : operands) {
Operation *nestedNewOp = mapping.lookup(operand->getOwner());
// Special case for the induction variable uses. We replace it with a
// version incremented based on the stage where it is used.
if (operand->get() == forOp.getInductionVar()) {
rewriter.setInsertionPoint(newOp);
Value offset = rewriter.create<arith::ConstantIndexOp>(
forOp.getLoc(), (maxStage - stages[op]) * step);
Value iv = rewriter.create<arith::AddIOp>(
forOp.getLoc(), newForOp.getInductionVar(), offset);
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
rewriter.setInsertionPointAfter(newOp);
continue;
}
auto arg = operand->get().dyn_cast<BlockArgument>();
if (arg && arg.getOwner() == forOp.getBody()) {
// If the value is a loop carried value coming from stage N + 1, remap
// it; it will become a direct use.
Value ret = forOp.getBody()->getTerminator()->getOperand(
arg.getArgNumber() - 1);
Operation *dep = ret.getDefiningOp();
if (!dep)
continue;
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
assert(stageDep->second == useStage + 1);
nestedNewOp->setOperand(operand->getOperandNumber(),
mapping.lookupOrDefault(ret));
continue;
}
// For an operand defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
Operation *def = operand->get().getDefiningOp();
if (!def)
continue;
auto stageDef = stages.find(def);
if (stageDef == stages.end() || stageDef->second == useStage)
continue;
auto remap = loopArgMap.find(
std::make_pair(operand->get(), useStage - stageDef->second));
assert(remap != loopArgMap.end());
nestedNewOp->setOperand(operand->getOperandNumber(),
newForOp.getRegionIterArgs()[remap->second]);
}

if (predicates[useStage]) {
newOp = predicateFn(newOp, predicates[useStage], rewriter);
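The stage-dependent induction-variable rewrite above reduces to one expression, restated here as a minimal sketch (same names as in createKernel; ivOffset is an illustrative name). For the test added below, maxStage = 1 and step = 1, so every stage-0 use of the IV inside the kernel is shifted by one iteration, which is exactly what the kernel CHECK lines match.

// Offset added to induction-variable uses of an op scheduled in `useStage`,
// assuming the same maxStage/step values as the surrounding kernel builder.
int64_t ivOffset = (maxStage - useStage) * step; // test below: (1 - 0) * 1 == 1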
44 changes: 44 additions & 0 deletions mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -627,3 +627,47 @@ func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result
} { __test_pipelining_loop__ }
return
}

// -----

// CHECK-LABEL: @backedge_mix_order
// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
// Prologue:
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
// Kernel:
// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[ARG1:.*]] = %[[L0]], %[[ARG2:.*]] = %[[L1]]) -> (f32, f32, f32) {
// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[C]], %[[ARG1]] : f32
// CHECK-NEXT: %[[IV3:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[IV4:.*]] = arith.addi %[[IV3]], %[[C1]] : index
// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[IV4]]] : memref<?xf32>
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ARG2]], %[[MUL0]] : f32
// CHECK-NEXT: scf.yield %[[MUL1]], %[[L2]], %[[L3]] : f32, f32, f32
// CHECK-NEXT: }
// Epilogue:
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[R]]#0, %[[R]]#1 : f32
// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[R]]#2, %[[MUL1]] : f32
// CHECK-NEXT: return %[[MUL2]] : f32
func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%cf = arith.constant 2.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
%A2_elem = arith.mulf %arg0, %A_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
%i1 = arith.addi %i0, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : index
%A1_elem = memref.load %A[%i1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
%A3_elem = arith.mulf %A1_elem, %A2_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 4 } : f32
scf.yield %A3_elem : f32
} { __test_pipelining_loop__ }
return %r : f32
}
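The test above drives the pipeliner through the __test_pipelining_stage__ and __test_pipelining_op_order__ attributes understood by the test pass. As a rough sketch of how the same interleaved schedule could be requested through the public pattern API: this assumes the long-standing scf::PipeliningOption struct and scf::populateSCFLoopPipeliningPatterns entry point (header paths approximate) and hard-codes the stage table to the five-op body of @backedge_mix_order.

#include <utility>
#include <vector>

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Pipeline a loop shaped like @backedge_mix_order: the body ops in program
// order are {load, mulf, addi, load, mulf}, scheduled with stages
// {0, 1, 0, 0, 1}, i.e. stage-0 and stage-1 ops interleaved.
static void addInterleavedPipeliningPatterns(RewritePatternSet &patterns) {
  scf::PipeliningOption options;
  options.getScheduleFn =
      [](scf::ForOp forOp,
         std::vector<std::pair<Operation *, unsigned>> &schedule) {
        SmallVector<Operation *> ops;
        for (Operation &op : forOp.getBody()->without_terminator())
          ops.push_back(&op);
        // Stage assignment hard-coded for the test loop's five-op body.
        const unsigned stages[] = {0, 1, 0, 0, 1};
        for (unsigned i = 0, e = ops.size(); i < e; ++i)
          schedule.emplace_back(ops[i], stages[i]);
      };
  scf::populateSCFLoopPipeliningPatterns(patterns, options);
}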
