Skip to content

Commit

Permalink
Enzyme ops removal for stablehlo.if
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Jan 7, 2025
1 parent ff4cc48 commit 3abe6b9
Showing 1 changed file with 196 additions and 0 deletions.
196 changes: 196 additions & 0 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "Enzyme/MLIR/Passes/RemovalUtils.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

Expand Down Expand Up @@ -1822,6 +1823,199 @@ class AutoDiffSort
}
};

/// Walks the operations directly contained in `block` (one branch of an if)
/// and rewrites the enzyme gradient/cache operations found there.
///
/// - `enzyme.set`: records gradient -> stored value in `mapping` and adds the
///   gradient to `gradients`.
/// - `enzyme.get`: redirects all uses of the result to the value currently
///   mapped for that gradient; when nothing is mapped yet, a new `enzyme.get`
///   is created through `builder` and recorded in `mapping`.
/// - `enzyme.push`: when the pushed value is not defined directly in `block`,
///   the push is re-created through `builder`, the original is erased, and
///   the matching pop is re-created in front of the operation that contains
///   it. Otherwise the cache is recorded in `caches`, keyed by the pushed
///   value (merged with any entry already stored for that value).
///
/// `builder` is used at whatever insertion point the caller configured; the
/// caller in this file positions it just before the if operation.
static void removalBlockExplore(Block *block, IRMapping &mapping,
                                OpBuilder &builder,
                                llvm::SmallDenseSet<Value> &gradients,
                                llvm::DenseMap<Value, CacheInfo> &caches) {
  auto cursor = block->begin();
  const auto stop = block->end();

  while (cursor != stop) {
    // Step past the current operation up front so that erasing it below
    // cannot invalidate the traversal.
    Operation *current = &*cursor;
    ++cursor;

    if (auto gradientStore = dyn_cast<enzyme::SetOp>(current)) {
      auto gradient = gradientStore.getGradient();
      mapping.map(gradient, gradientStore.getValue());
      gradients.insert(gradient);
      continue;
    }

    if (auto gradientLoad = dyn_cast<enzyme::GetOp>(current)) {
      auto gradient = gradientLoad.getGradient();
      Value replacement = mapping.lookupOrNull(gradient);
      if (!replacement) {
        // No value known in this branch yet: read the gradient once at the
        // builder's insertion point and reuse that read from now on.
        replacement = builder.create<enzyme::GetOp>(
            gradientLoad->getLoc(), gradientLoad.getResult().getType(),
            gradient);
        mapping.map(gradient, replacement);
      }
      gradientLoad.getResult().replaceAllUsesWith(replacement);
      continue;
    }

    auto cachePush = dyn_cast<enzyme::PushOp>(current);
    if (!cachePush)
      continue;

    CacheInfo info(cachePush.getCache());
    Value pushed = info.pushedValue();

    if (pushed.getParentBlock() == block) {
      // The value is produced inside this branch: remember the cache so the
      // caller can deal with it, folding in any cache already recorded for
      // the same value.
      auto known = caches.find(pushed);
      if (known != caches.end())
        info = info.merge(known->second);
      caches[pushed] = info;
      continue;
    }

    // The value is not defined directly in this branch, so the push can be
    // emitted at the builder's insertion point instead of inside the branch.
    builder.create<enzyme::PushOp>(cachePush->getLoc(), cachePush.getCache(),
                                   pushed);
    cachePush->erase();

    // Re-create the matching pop in front of the operation that contains it.
    OpBuilder::InsertionGuard restoreInsertionPoint(builder);
    builder.setInsertionPoint(info.popOp->getParentOp());

    auto hoistedPop = builder.create<enzyme::PopOp>(
        info.popOp->getLoc(), pushed.getType(), info.popOp.getCache());
    info.popOp.getResult().replaceAllUsesWith(hoistedPop);
    info.popOp->erase();
  }
}

/// External model that moves enzyme gradient/cache operations out of the
/// branches of a `stablehlo.if` by threading the affected values through
/// additional results of a rebuilt if operation.
struct IfOpEnzymeOpsRemover
    : public EnzymeOpsRemoverOpInterface::ExternalModel<IfOpEnzymeOpsRemover,
                                                        stablehlo::IfOp> {
  /// Rewrites `op` (a `stablehlo.if`) so that gradient sets and cache pushes
  /// found directly in its branches happen after the if instead.
  ///
  /// Returns failure when `enzyme::removeOpsWithinBlock` fails on either
  /// branch, success otherwise. When something has to be threaded through the
  /// if, `op` is replaced by a new `stablehlo.if` and erased.
  LogicalResult removeEnzymeOps(Operation *op) const {
    // Gradients:
    //
    //  For each set in a branch, we instead set after the if by using the
    //  return value.
    //
    //  if %pred {
    //    enzyme.set %grad, %2
    //  } else {
    //  }
    //
    //  %0 = enzyme.get %grad
    //  %1 = if %pred {
    //    return %2
    //  } else {
    //    return %0
    //  }
    //  enzyme.set %grad, %1
    //
    //  For each get in a branch, we get before the if and use that value
    //  instead of the get.

    // Caches:
    //
    //  For each push in a branch, we push after the if instead; the branch
    //  that did not produce the value returns a dummy (null) value.
    //
    //  For each pop in the reverse if, we pop before that if instead of
    //  inside a branch.

    auto ifOp = cast<IfOp>(op);

    Block *trueBlock = &ifOp.getTrueBranch().front(),
          *falseBlock = &ifOp.getFalseBranch().front();

    if (enzyme::removeOpsWithinBlock(trueBlock).failed() ||
        enzyme::removeOpsWithinBlock(falseBlock).failed()) {
      return failure();
    }

    // Gradients whose value is set in either branch.
    //
    // NOTE(review): both containers are hashed, so their iteration order is
    // unspecified. The loops below only rely on the order being the same
    // between two traversals of an unmodified container (so operands and
    // results line up), but the order of the extra results may differ from
    // run to run. Switching to llvm::SetVector / llvm::MapVector would make
    // the output deterministic; that requires changing the signature of
    // removalBlockExplore as well.
    llvm::SmallDenseSet<Value> gradients;

    // We assume pushes are exclusive.
    llvm::DenseMap<Value, CacheInfo> pushedCaches;

    // Gradient -> value it holds at the end of the respective branch.
    IRMapping trueMapping, falseMapping;
    OpBuilder builder(ifOp);

    removalBlockExplore(trueBlock, trueMapping, builder, gradients,
                        pushedCaches);
    removalBlockExplore(falseBlock, falseMapping, builder, gradients,
                        pushedCaches);

    // Nothing needs to be threaded through the if: keep the operation as is
    // instead of rebuilding an identical one.
    if (gradients.empty() && pushedCaches.empty())
      return success();

    Operation *trueTerm = trueBlock->getTerminator();
    Operation *falseTerm = falseBlock->getTerminator();

    // Value a branch yields for `grad`: the value recorded for that branch if
    // there is one, otherwise the gradient as read before the if.
    auto branchValueFor = [&](IRMapping &mapping, Value grad) -> Value {
      if (Value known = mapping.lookupOrNull(grad))
        return known;
      return builder
          .create<enzyme::GetOp>(
              grad.getLoc(),
              cast<enzyme::GradientType>(grad.getType()).getBasetype(), grad)
          .getResult();
    };

    for (Value grad : gradients) {
      Value trueValue = branchValueFor(trueMapping, grad);
      trueTerm->insertOperands(trueTerm->getNumOperands(),
                               ValueRange(trueValue));

      Value falseValue = branchValueFor(falseMapping, grad);
      falseTerm->insertOperands(falseTerm->getNumOperands(),
                                ValueRange(falseValue));
    }

    for (auto &[pushedValue, info] : pushedCaches) {
      // Placeholder returned by the branch that does not define the value.
      Value dummy = cast<AutoDiffTypeInterface>(pushedValue.getType())
                        .createNullValue(builder, pushedValue.getLoc());

      Value trueValue =
          pushedValue.getParentBlock() == trueBlock ? pushedValue : dummy;
      Value falseValue =
          pushedValue.getParentBlock() == falseBlock ? pushedValue : dummy;

      trueTerm->insertOperands(trueTerm->getNumOperands(),
                               ValueRange(trueValue));
      falseTerm->insertOperands(falseTerm->getNumOperands(),
                                ValueRange(falseValue));
    }

    // Rebuild the if with the widened result list and move both bodies over.
    auto newIf = builder.create<stablehlo::IfOp>(
        ifOp->getLoc(), trueTerm->getOperandTypes(), ifOp.getPred());
    newIf.getTrueBranch().takeBody(ifOp.getTrueBranch());
    newIf.getFalseBranch().takeBody(ifOp.getFalseBranch());

    // The original results are the leading results of the new if. Redirect
    // their users now: the old operation is erased below and must not have
    // any remaining uses at that point.
    const unsigned numOrigResults = ifOp->getNumResults();
    for (unsigned i = 0; i < numOrigResults; ++i)
      ifOp->getResult(i).replaceAllUsesWith(newIf->getResult(i));

    // Extra results follow in the same order the operands were appended:
    // first one per gradient, then one per pushed value.
    size_t idx = numOrigResults;
    for (Value grad : gradients) {
      builder.create<enzyme::SetOp>(grad.getLoc(), grad, newIf->getResult(idx));
      idx++;
    }

    for (auto &[pushedValue, info] : pushedCaches) {
      builder.create<enzyme::PushOp>(info.pushOp->getLoc(),
                                     info.initOp.getResult(),
                                     newIf->getResult(idx));
      info.pushOp->erase();

      // Re-create the matching pop in front of the operation containing it.
      OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPoint(info.popOp->getParentOp());

      auto newPop = builder.create<enzyme::PopOp>(
          info.popOp->getLoc(), info.popOp.getResult().getType(),
          info.popOp.getCache());
      info.popOp.getResult().replaceAllUsesWith(newPop);
      info.popOp->erase();

      idx++;
    }

    ifOp->erase();

    return success();
  }
};

} // namespace

void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
Expand All @@ -1832,6 +2026,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(

// SortOp::attachInterface<AutoDiffSort>(*context);

IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);

WhileOp::attachInterface<ADDataFlowWhileOp>(*context);
SortOp::attachInterface<ADDataFlowSortOp>(*context);
ScatterOp::attachInterface<ADDataFlowScatterOp>(*context);
Expand Down

0 comments on commit 3abe6b9

Please sign in to comment.