Skip to content

Commit

Permalink
Expose materialize functions in Chlo to Stablehlo lowering (#2665)
Browse files Browse the repository at this point in the history
We want to perform constant propogation through `chlo.lgamma` in
Enzyme-JaX


[Kevin](EnzymeAD/Enzyme-JAX#182 (comment))
mentioned he was open to exposing some materialize functions (which are
currently static, and not callable from [our
pass](https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp)
atm)



@wsmoses @GleasonK
  • Loading branch information
vimarsh6739 authored Dec 15, 2024
1 parent 2f6be83 commit 38fe0f4
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 35 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,7 @@ cc_library(
"stablehlo/transforms/VhloToVersion.cpp",
],
hdrs = [
"stablehlo/transforms/ChloDecompositionUtils.h",
"stablehlo/transforms/MapStablehloToVhlo.h",
"stablehlo/transforms/PassUtils.h",
"stablehlo/transforms/Passes.h",
Expand Down
36 changes: 36 additions & 0 deletions stablehlo/transforms/ChloDecompositionUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_
#define STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_

#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace stablehlo {

// Utility functions used in the Chlo to stablehlo legalization.

Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args);

Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args);

Value materializeZeta(OpBuilder &rewriter, Location loc, ValueRange args);

Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args);

} // namespace stablehlo
} // namespace mlir

#endif // STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_
81 changes: 46 additions & 35 deletions stablehlo/transforms/ChloLegalizeToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "stablehlo/dialect/BroadcastUtils.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/ChloDecompositionUtils.h"
#include "stablehlo/transforms/PassUtils.h"
#include "stablehlo/transforms/Passes.h"

Expand Down Expand Up @@ -462,8 +463,7 @@ struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {

template <typename FTy>
static Value materializeChebyshevPolynomialApproximation(
ConversionPatternRewriter &rewriter, Location loc, Value x,
ArrayRef<FTy> coefficients) {
OpBuilder &rewriter, Location loc, Value x, ArrayRef<FTy> coefficients) {
Value b0 = getConstantLike(rewriter, loc, 0.0, x);
Value b1 = getConstantLike(rewriter, loc, 0.0, x);
Value b2 = getConstantLike(rewriter, loc, 0.0, x);
Expand All @@ -483,9 +483,10 @@ static Value materializeChebyshevPolynomialApproximation(
}

template <typename FTy>
static Value materializeBesselI1eApproximation(
ConversionPatternRewriter &rewriter, Location loc, Value x,
ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) {
static Value materializeBesselI1eApproximation(OpBuilder &rewriter,
Location loc, Value x,
ArrayRef<FTy> kI1eCoeffsA,
ArrayRef<FTy> kI1eCoeffsB) {
Value z = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value two = getConstantLike(rewriter, loc, 2.0, x);
Expand Down Expand Up @@ -515,8 +516,8 @@ static Value materializeBesselI1eApproximation(
loc, rewriter.create<mlir::stablehlo::SignOp>(loc, x), select);
}

Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value materializeBesselI1eApproximationF32(OpBuilder &rewriter, Location loc,
ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
"expect f32 element type");
Expand All @@ -541,8 +542,9 @@ Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
kI1eCoeffsB);
}

static Value materializeBesselI1eApproximationF64(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
static Value materializeBesselI1eApproximationF64(OpBuilder &rewriter,
Location loc,
ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
"expect f64 element type");
Expand Down Expand Up @@ -586,8 +588,8 @@ static Value materializeBesselI1eApproximationF64(
static Value materializeWithUpcast(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args,
FloatType minPrecisionTy,
Value callback(ConversionPatternRewriter &,
Location, ValueRange)) {
Value callback(OpBuilder &, Location,
ValueRange)) {
Type originalTy = getElementTypeOrSelf(args.front().getType());
auto floatOriginalTy = dyn_cast<FloatType>(originalTy);
bool needsUpcast =
Expand Down Expand Up @@ -645,9 +647,9 @@ struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
};

template <typename FTy>
static Value materializePolynomialApproximation(
ConversionPatternRewriter &rewriter, Location loc, Value x,
ArrayRef<FTy> coefficients) {
static Value materializePolynomialApproximation(OpBuilder &rewriter,
Location loc, Value x,
ArrayRef<FTy> coefficients) {
if (coefficients.empty()) return getConstantLike(rewriter, loc, 0.0, x);

Value poly = getConstantLike(rewriter, loc, coefficients[0], x);
Expand Down Expand Up @@ -836,7 +838,7 @@ static Value materializeErfcApproximationF64(
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
static Value materializeErfcApproximationF32ForMagnitudeGeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
OpBuilder &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
"expect f32 element type");
Expand Down Expand Up @@ -902,7 +904,7 @@ static Value materializeErfcApproximationF32ForMagnitudeGeOne(
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
static Value materializeErfApproximationF32ForMagnitudeLeOne(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
OpBuilder &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
"expect f32 element type");
Expand All @@ -921,8 +923,8 @@ static Value materializeErfApproximationF32ForMagnitudeLeOne(
}

// This is the same approximation as used in Eigen.
static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
static Value materializeErfApproximationF32(OpBuilder &rewriter, Location loc,
ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
"expect f32 element type");
Expand Down Expand Up @@ -958,8 +960,8 @@ static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
erf, ubErf);
}

static Value materializeErfcApproximationF32(
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
static Value materializeErfcApproximationF32(OpBuilder &rewriter, Location loc,
ValueRange args) {
Value x = args.front();
assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
"expect f32 element type");
Expand Down Expand Up @@ -1041,8 +1043,7 @@ struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
}
};

static Value erfInv32(ConversionPatternRewriter &b, Location loc,
ValueRange args) {
static Value erfInv32(OpBuilder &b, Location loc, ValueRange args) {
constexpr int kDegree = 9;
constexpr std::array<float, 9> wLessThan5Constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
Expand Down Expand Up @@ -1248,6 +1249,8 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};

} // namespace

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
Expand All @@ -1257,8 +1260,7 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
static Value materializeLgamma(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
Expand Down Expand Up @@ -1393,6 +1395,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
}

namespace {

// Express `cosh` as
// cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2))
Expand All @@ -1403,8 +1407,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter,
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
static Value materializeCoshApproximation(OpBuilder &rewriter, Location loc,
ValueRange operands) {
mlir::chlo::CoshOp::Adaptor transformed(operands);
Value x = transformed.getOperand();

Expand All @@ -1431,6 +1435,8 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
}
};

} // namespace

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
Expand All @@ -1439,8 +1445,7 @@ struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
static Value materializeDigamma(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
Expand Down Expand Up @@ -1545,14 +1550,16 @@ static Value materializeDigamma(ConversionPatternRewriter &rewriter,
digamma);
}

namespace {

static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
Value val) {
auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
return getConstantLike(
b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
}

static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
static Value materializeZeta(OpBuilder &rewriter, Location loc,
ValueRange args) {
// Implementation ported from:
// https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917
Expand Down Expand Up @@ -1703,8 +1710,9 @@ static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
return output;
}

static Value materializePolygamma(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
} // namespace

Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args) {
mlir::chlo::PolygammaOp::Adaptor transformed(args);
Value n = transformed.getN();
Value x = transformed.getX();
Expand Down Expand Up @@ -1747,6 +1755,8 @@ static Value materializePolygamma(ConversionPatternRewriter &rewriter,
result);
}

namespace {

struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -1901,8 +1911,9 @@ struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value materializeSinhApproximationForLargeX(
ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) {
static Value materializeSinhApproximationForLargeX(OpBuilder &rewriter,
Location loc,
ValueRange operands) {
mlir::chlo::SinhOp::Adaptor transformed(operands);
Value x = transformed.getOperand();

Expand All @@ -1918,8 +1929,8 @@ static Value materializeSinhApproximationForLargeX(
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
static Value materializeSinhApproximation(OpBuilder &rewriter, Location loc,
ValueRange operands) {
Value largeSinhResult =
materializeSinhApproximationForLargeX(rewriter, loc, operands);

Expand Down

0 comments on commit 38fe0f4

Please sign in to comment.