Skip to content

Commit

Permalink
Explicit Enzyme rules
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Nov 8, 2024
1 parent e99db76 commit a2ff88a
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 20 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Expand All @@ -36,7 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"]
BijectorsEnzymeCoreExt = ["EnzymeCore"]
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsMooncakeExt = "Mooncake"
Expand Down
55 changes: 55 additions & 0 deletions ext/BijectorsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
module BijectorsEnzymeCoreExt

if isdefined(Base, :get_extension)
using EnzymeCore: Const, Duplicated, DuplicatedNoNeed, EnzymeRules
using Bijectors: find_alpha
else
using ..EnzymeCore: Const, Duplicated, DuplicatedNoNeed, EnzymeRules
using ..Bijectors: find_alpha
end

_add(x::Real, ::Nothing) = x
_add(x::Real, y::Real) = x + y

_muladd(x::Real, y::Real, ::Nothing) = x * y
_muladd(x::Real, y::Real, z::Real) = muladd(x, y, z)

function EnzymeRules.forward(
::Const{typeof(find_alpha)},
::Type{RT},
wt_y::Union{Const,Duplicated},
wt_u_hat::Union{Const,Duplicated},
b::Union{Const,Duplicated},
) where {RT<:Union{Const,DuplicatedNoNeed,Duplicated}}
# Compute primal value
Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val)

# Early exit if no derivatives are requested
if RT <: Const
return Ω
end

if wt_y isa Const && wt_u_hat isa Const && b isa Const
# Trivial case: All partial derivatives are 0
Ω̇ = zero(Ω)
else
# In all other cases we have to compute the partial derivatives
# As in the rule for ChainRules, we reuse the following term in the computation of all derivatives
c = wt_u_hat.val * sech+ b.val)^2

x = b isa Duplicated ? (-c) * b.dval : nothing
y = wt_u_hat isa Duplicated ? _muladd(-tanh+ b.val), wt_u_hat.dval, x) : x
z = wt_y isa Duplicated ? _add(wt_y.dval, y) : y

Ω̇ = z / (1 + c)
end

if RT <: Duplicated
return Duplicated(Ω, Ω̇)
else
@assert RT <: DuplicatedNoNeed
return Ω̇
end
end

end # module
18 changes: 0 additions & 18 deletions ext/BijectorsEnzymeExt.jl

This file was deleted.

2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -33,6 +34,7 @@ Combinatorics = "1.0.2"
Compat = "3.46, 4.2"
DistributionsAD = "0.6.3"
Enzyme = "0.12.22"
EnzymeTestUtils = "0.1.8"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10.12"
Expand Down
15 changes: 15 additions & 0 deletions test/ad/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@testset "Enzyme: Bijectors.find_alpha" begin
x = randn()
y = expm1(randn())
z = randn()

@testset "forward" begin
@testset for RT in (Const, DuplicatedNoNeed, Duplicated),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated),
Tz in (Const, Duplicated)

test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ChainRulesTestUtils
using Combinatorics
using DistributionsAD
using Enzyme
using EnzymeTestUtils
using FiniteDifferences
using ForwardDiff
using Functors
Expand Down Expand Up @@ -68,6 +69,7 @@ end

if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/enzyme.jl")
include("ad/flows.jl")
include("ad/pd.jl")
include("ad/corr.jl")
Expand Down

0 comments on commit a2ff88a

Please sign in to comment.