diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index a3e9fad6..dca5ab38 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -15,6 +15,7 @@
 #include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtils.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
+#include "Enzyme/MLIR/Passes/RemovalUtils.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/LogicalResult.h"
@@ -1822,6 +1823,199 @@ class AutoDiffSort
   }
 };
 
+static void removalBlockExplore(Block *block, IRMapping &mapping,
+                                OpBuilder &builder,
+                                llvm::SmallDenseSet<Value> &gradients,
+                                llvm::DenseMap<Value, CacheInfo> &caches) {
+  for (auto it = block->begin(), e = block->end(); it != e;) {
+    Operation *op = &*it;
+
+    if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
+      auto grad = setOp.getGradient();
+      auto value = setOp.getValue();
+      mapping.map(grad, value);
+      gradients.insert(grad);
+    }
+
+    if (auto getOp = dyn_cast<enzyme::GetOp>(op)) {
+      auto grad = getOp.getGradient();
+      Value value = mapping.lookupOrNull(getOp.getGradient());
+      if (!value) {
+        value = builder.create<enzyme::GetOp>(
+            getOp->getLoc(), getOp.getResult().getType(), grad);
+        mapping.map(grad, value);
+      }
+      getOp.getResult().replaceAllUsesWith(value);
+    }
+
+    if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
+      CacheInfo info(pushOp.getCache());
+
+      Value pushedValue = info.pushedValue();
+
+      // If the pushed value is defined above the if, we can push it before
+      // the if directly.
+      if (pushedValue.getParentBlock() != block) {
+        builder.create<enzyme::PushOp>(pushOp->getLoc(), pushOp.getCache(),
+                                       pushedValue);
+
+        ++it; // Increment the iterator to allow in-place deletion.
+        pushOp->erase();
+
+        // Move the matching pop before the reverse if.
+        OpBuilder::InsertionGuard guard(builder);
+        builder.setInsertionPoint(info.popOp->getParentOp());
+
+        auto newPop = builder.create<enzyme::PopOp>(
+            info.popOp->getLoc(), pushedValue.getType(),
+            info.popOp.getCache());
+        info.popOp.getResult().replaceAllUsesWith(newPop);
+        info.popOp->erase();
+
+        continue;
+      }
+
+      if (caches.contains(pushedValue)) {
+        info = info.merge(caches.lookup(pushedValue));
+      }
+      caches[pushedValue] = info;
+    }
+
+    ++it;
+  }
+}
+
+struct IfOpEnzymeOpsRemover
+    : public EnzymeOpsRemoverOpInterface::ExternalModel<IfOpEnzymeOpsRemover,
+                                                        IfOp> {
+  LogicalResult removeEnzymeOps(Operation *op) const {
+    // Gradients:
+    //
+    // For each gradient set in a branch, we instead set it after the if,
+    // using the corresponding if result:
+    //
+    // if %pred {
+    //   enzyme.set %grad, %2
+    // } else {
+    // }
+    //
+    // %0 = enzyme.get %grad
+    // %1 = if %pred {
+    //   return %2
+    // } else {
+    //   return %0
+    // }
+    // enzyme.set %grad, %1
+    //
+    // For each get in a branch, we get before the if and use that value
+    // instead of the get.

+    // Caches:
+    //
+    // For each push in a branch, we push after the if instead, yielding a
+    // dummy value from the other branch.
+    //
+    // For each pop in the reverse if, we pop before the if instead of
+    // inside a branch.
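+    //
+    // For example (an illustrative sketch rather than exact dialect
+    // syntax), for a value %2 pushed only in the true branch:
+    //
+    // if %pred {
+    //   enzyme.push %cache, %2
+    // } else {
+    // }
+    //
+    // becomes
+    //
+    // %1 = if %pred {
+    //   return %2
+    // } else {
+    //   return %dummy
+    // }
+    // enzyme.push %cache, %1
+    //
+    // and the matching pop in the reverse if is hoisted to just before the
+    // reverse if.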
+
+    auto ifOp = cast<IfOp>(op);
+
+    Block *trueBlock = &ifOp.getTrueBranch().front(),
+          *falseBlock = &ifOp.getFalseBranch().front();
+
+    if (enzyme::removeOpsWithinBlock(trueBlock).failed() ||
+        enzyme::removeOpsWithinBlock(falseBlock).failed()) {
+      return failure();
+    }
+
+    // Gradients whose value is set in either branch.
+    llvm::SmallDenseSet<Value> gradients;
+
+    // We assume pushes are exclusive, i.e. a given cache is pushed in at
+    // most one branch.
+    llvm::DenseMap<Value, CacheInfo> pushedCaches;
+
+    // Gradient-to-value mappings for the two branches.
+    IRMapping trueMapping, falseMapping;
+    OpBuilder builder(ifOp);
+
+    removalBlockExplore(trueBlock, trueMapping, builder, gradients,
+                        pushedCaches);
+    removalBlockExplore(falseBlock, falseMapping, builder, gradients,
+                        pushedCaches);
+
+    Operation *trueTerm = trueBlock->getTerminator();
+    Operation *falseTerm = falseBlock->getTerminator();
+
+    for (auto grad : gradients) {
+      auto trueValue = trueMapping.lookupOrNull(grad);
+      if (!trueValue) {
+        trueValue = builder.create<enzyme::GetOp>(
+            grad.getLoc(),
+            grad.getType().cast<enzyme::GradientType>().getBasetype(), grad);
+      }
+      trueTerm->insertOperands(trueTerm->getNumOperands(),
+                               ValueRange(trueValue));
+
+      auto falseValue = falseMapping.lookupOrNull(grad);
+      if (!falseValue) {
+        falseValue = builder.create<enzyme::GetOp>(
+            grad.getLoc(),
+            grad.getType().cast<enzyme::GradientType>().getBasetype(), grad);
+      }
+      falseTerm->insertOperands(falseTerm->getNumOperands(),
+                                ValueRange(falseValue));
+    }
+
+    for (auto &[pushedValue, info] : pushedCaches) {
+      Value dummy =
+          pushedValue.getType().cast<AutoDiffTypeInterface>().createNullValue(
+              builder, pushedValue.getLoc());
+
+      Value trueValue =
+          pushedValue.getParentBlock() == trueBlock ? pushedValue : dummy;
+      Value falseValue =
+          pushedValue.getParentBlock() == falseBlock ? pushedValue : dummy;
+
+      trueTerm->insertOperands(trueTerm->getNumOperands(),
+                               ValueRange(trueValue));
+      falseTerm->insertOperands(falseTerm->getNumOperands(),
+                                ValueRange(falseValue));
+    }
+
+    auto newIf = builder.create<IfOp>(
+        ifOp->getLoc(), trueTerm->getOperandTypes(), ifOp.getPred());
+    newIf.getTrueBranch().takeBody(ifOp.getTrueBranch());
+    newIf.getFalseBranch().takeBody(ifOp.getFalseBranch());
+
+    size_t idx = ifOp->getNumResults();
+    for (auto grad : gradients) {
+      builder.create<enzyme::SetOp>(grad.getLoc(), grad,
+                                    newIf->getResult(idx));
+      idx++;
+    }
+
+    for (auto &[pushedValue, info] : pushedCaches) {
+      builder.create<enzyme::PushOp>(info.pushOp->getLoc(),
+                                     info.initOp.getResult(),
+                                     newIf->getResult(idx));
+      info.pushOp->erase();
+
+      OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPoint(info.popOp->getParentOp());
+
+      auto newPop = builder.create<enzyme::PopOp>(
+          info.popOp->getLoc(), info.popOp.getResult().getType(),
+          info.popOp.getCache());
+      info.popOp.getResult().replaceAllUsesWith(newPop);
+      info.popOp->erase();
+
+      idx++;
+    }
+
+    // Forward the original if results to their users before erasing the old
+    // op; the leading results of the new if correspond to the old ones.
+    ifOp->replaceAllUsesWith(
+        newIf->getResults().take_front(ifOp->getNumResults()));
+
+    ifOp->erase();
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
@@ -1832,6 +2026,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
     // SortOp::attachInterface<AutoDiffSort>(*context);
 
+    IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);
+
     WhileOp::attachInterface<WhileOpEnzymeOpsRemover>(*context);
     SortOp::attachInterface<ADDataFlowSortOp>(*context);
     ScatterOp::attachInterface<ADDataFlowScatterOp>(*context);
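Note: a minimal sketch of how the attached interface might be driven by a
cleanup pass. The collect-then-rewrite loop, moduleOp, and worklist below are
assumptions for illustration; only EnzymeOpsRemoverOpInterface and
removeEnzymeOps come from the patch above.

    // Collect ops first, since removeEnzymeOps may erase the visited op
    // (e.g. the stablehlo.if is replaced by a new if with extra results).
    SmallVector<Operation *> worklist;
    moduleOp->walk([&](Operation *op) {
      if (isa<EnzymeOpsRemoverOpInterface>(op))
        worklist.push_back(op);
    });
    for (Operation *op : worklist)
      if (failed(cast<EnzymeOpsRemoverOpInterface>(op).removeEnzymeOps()))
        return failure(); // propagate failure to the caller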