Replacing the EnsembleProblem with an equivalent simple for loop breaks the training #54

gabrevaya opened this issue Sep 15, 2021 · 2 comments


If we replace the EnsembleProblem that solves the differential equations inside diffeq_layer with a simple for loop, whose forward pass gives exactly the same result, training breaks completely. You can quickly try this by overriding the diffeq_layer function and running the pendulum example or tutorial:

```julia
# Assumes LatentDiffEq, Flux and OrdinaryDiffEq are already loaded (e.g. by the pendulum example)
import LatentDiffEq.Decoder
import LatentDiffEq.diffeq_layer
import LatentDiffEq.transform_after_diffeq

function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    prob = remake(prob; tspan = (t[1],t[end]))
    sols = Array{Float32,2}[]

    # Solve each trajectory separately instead of using an EnsembleProblem
    for i in 1:size(ẑ₀,2)
        prob = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
        sol = solve(prob, solver; sensealg = sensealg, saveat = t, kwargs...)
        push!(sols, Array(sol))
    end

    # Stack the (state × time) matrices into a (state × trajectory × time) array
    ẑ = Flux.stack(sols, 2)
    ẑ = transform_after_diffeq(ẑ, decoder.diffeq)
end

# Identity by default
transform_after_diffeq(x, diffeq) = x
```

For reference, this is the original diffeq_layer function:

```julia
function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    # Function definition for ensemble problem
    prob_func(prob,i,repeat) = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])

    # Check if solve was successful; if not, return NaNs to avoid problems with dimensions
    output_func(sol, i) = sol.retcode == :Success ? (Array(sol), false) : (fill(NaN32,(size(ẑ₀, 1), length(t))), false)

    ## Adapt problem to given time span and create ensemble problem definition
    prob = remake(prob; tspan = (t[1], t[end]))
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)

    ## Solve
    ẑ = solve(ens_prob, solver, EnsembleThreads(); sensealg = sensealg, trajectories = size(θ̂, 2), saveat = t, kwargs...)

    # Transform the resulting output (mainly used for Kuramoto-like systems)
    ẑ = transform_after_diffeq(ẑ, decoder.diffeq)
    ẑ = permutedims(ẑ, [1,3,2])
end
```
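To see why both versions can give the same forward output, here is a Base-Julia-only sketch of the two array layouts, with small hypothetical matrices standing in for real ODE solutions. It assumes the ensemble solution converts to a state × time × trajectory array, which is what the permutedims(ẑ, [1,3,2]) call implies, and that Flux.stack(sols, 2) inserts a singleton second axis into each matrix and concatenates along it (as in Flux 0.12).

```julia
# Hypothetical per-trajectory outputs: 2 states × 5 time points, 3 trajectories
sols = [Float32.(reshape(1:10, 2, 5)) .+ 100i for i in 1:3]

# Ensemble-style layout: (state, time, trajectory), then permuted to (state, trajectory, time)
z_ens = permutedims(cat(sols...; dims = 3), [1, 3, 2])

# Loop-style layout: give each matrix a singleton 2nd axis and concatenate along it,
# mirroring what Flux.stack(sols, 2) does
z_loop = cat([reshape(s, size(s, 1), 1, size(s, 2)) for s in sols]...; dims = 2)

@show size(z_ens)        # (2, 3, 5)
@show z_ens == z_loop    # true
```

Identical layouts only explain why the forward outputs match; they say nothing about whether the two code paths produce the same gradients.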


The forward output is exactly the same; however, training malfunctions: after a few epochs the loss increases sharply (>> 1e6 after 50 epochs in the default pendulum example).

I'll try to put together an MWE and compare the gradients. This might hint at a sensitivity-analysis, Zygote, or primitives issue, which could also explain why the pendulum example in the Python implementation of GOKU nets converges faster and is more robust to different random seeds when using exactly the same architecture, hyperparameters and initializations.

Maybe I'm doing something wrong, from DiffEqFlux's point of view, in the simple for-loop version? @ChrisRackauckas, do you see anything wrong here a priori?

gabrevaya (owner, issue author) commented:

Here is an MWE in the context of LatentDiffEq.jl. Later I'll create a more general, minimal MWE.
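For reference, the Pendulum type in the MWE integrates the frictionless pendulum written in f!, with the gravitational constant fixed at G = 10 and the length L read from the single parameter p[1], which comes from θ̂ in diffeq_layer:

```math
\frac{dx}{dt} = y, \qquad \frac{dy}{dt} = -\frac{G}{L}\sin x, \qquad G = 10, \quad L = p_1
```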

```julia
using LatentDiffEq
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Random
import LatentDiffEq.Decoder
import LatentDiffEq.diffeq_layer


struct Pendulum{P,S,T,K}
    prob::P
    solver::S
    sensealg::T
    kwargs::K

    function Pendulum(; kwargs...)
        # Parameters and initial conditions only
        # used to initialize the ODE problem
        u₀ = Float32[1.0, 1.0]
        p = Float32[1.]
        tspan = (0.f0, 1.f0)

        # Define differential equations
        function f!(du, u, p, t)
            x, y = u
            G = 10.0f0
            L = p[1]
            du[1] = y
            du[2] = -G/L*sin(x)
        end

        # Build ODE Problem
        prob = ODEProblem(f!, u₀, tspan, p)

        # Choose a solver and sensitivity algorithm
        solver = Tsit5()
        sensalg = ForwardDiffSensitivity()
        # sensalg = BacksolveAdjoint()
        # sensalg = InterpolatingAdjoint()
        # sensalg = QuadratureAdjoint()

        P = typeof(prob)
        S = typeof(solver)
        T = typeof(sensalg)
        K = typeof(kwargs)
        new{P,S,T,K}(prob, solver, sensalg, kwargs)
    end
end

model_type = GOKU()
diffeq = Pendulum()
input_dim = 784

encoder_layers, decoder_layers = default_layers(model_type, input_dim, diffeq)
model = LatentDiffEqModel(model_type, encoder_layers, decoder_layers)
ps = Flux.params(model)
l̂ = (rand(Float32, 2, 64), rand(Float32, 1, 64))
t = range(0.f0, step=0.05, length=50)
loss(model, l̂, t) = sum(diffeq_layer(model.decoder, l̂, t) .- 1)

# Returns the forward output, the loss computed directly, and the
# loss value and pullback returned by Flux.pullback
function evaluate(model, l̂, t)
    out = diffeq_layer(model.decoder, l̂, t)
    l_direct = loss(model, l̂, t)

    l, back = Flux.pullback(ps) do
        loss(model, l̂, t)
    end

    return out, l_direct, l, back
end

# Version 1: EnsembleProblem, solved serially
function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs
    # Function definition for ensemble problem
    prob_func(prob,i,repeat) = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
    # Check if solve was successful; if not, return NaNs to avoid dimension mismatches
    output_func(sol, i) = sol.retcode == :Success ? (Array(sol), false) : (fill(NaN32,(size(ẑ₀, 1), length(t))), false)
    ## Adapt problem to given time span and create ensemble problem definition
    prob = remake(prob; tspan = (t[1],t[end]))
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)
    ## Solve
    ẑ = solve(ens_prob, solver, EnsembleSerial(); sensealg = sensealg, trajectories = size(θ̂, 2), saveat = t, kwargs...)
    ẑ = permutedims(ẑ, [1,3,2])
end

res1 = evaluate(model, l̂, t)

# Version 2: simple for loop
function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    prob = remake(prob; tspan = (t[1],t[end]))
    sols = Array{Float32,2}[]

    for i in 1:size(ẑ₀,2)
        prob = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
        sol = solve(prob, solver; sensealg = sensealg, saveat = t, kwargs...)
        push!(sols, Array(sol))
    end
    ẑ = Flux.stack(sols, 2)
end

res2 = evaluate(model, l̂, t)

@show res1 .== res2

@show res1[2]
@show res1[3]

@show res2[2]
@show res2[3]
```

If ForwardDiffSensitivity() or BacksolveAdjoint() are used as sensalg (chosen inside the Pendulum struct), the results are the following:

```
res1 .== res2 = (true, true, true, false)
res1[2] = -6916.1006f0
res1[3] = -6916.1006f0
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0
```

If InterpolatingAdjoint() or QuadratureAdjoint() are used as sensalg:

```
res1 .== res2 = (true, true, false, false)
res1[2] = -6916.1006f0
res1[3] = -6916.099838162956
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0
```
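A quick check on the two pullback loss values from this second case (numbers copied from the output above) suggests that they differ only at the level of Float32 rounding, and that the ensemble-based pullback returned a Float64 rather than a Float32. This is an observation on the printed values, not a diagnosis of the cause:

```julia
# Loss values from Flux.pullback in the InterpolatingAdjoint/QuadratureAdjoint case
l_loop     = -6916.1006f0          # res2[3], printed with the f0 suffix, so Float32
l_ensemble = -6916.099838162956    # res1[3], printed without a suffix, so Float64

abs_diff = abs(Float64(l_loop) - l_ensemble)
rel_diff = abs_diff / abs(l_ensemble)

@show typeof(l_ensemble)       # Float64
@show abs_diff < eps(6916f0) * 2   # true: under two Float32 spacings near 6916
@show rel_diff < eps(Float32)      # true: relative gap below Float32 machine epsilon
```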

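So in both cases the back closures returned by Flux.pullback compare unequal. Note that == on closures falls back to ===, so this comparison alone does not show that the gradients differ; comparing what the back closures return would. In the second case, even the loss values returned by Flux.pullback do not match.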
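Why an unequal back is expected even when the gradients agree: a closure is compared by identity, and a pullback over Params captures fresh mutable state on every call. The Base-only sketch below uses a hypothetical fake_pullback (not Zygote's API) as a stand-in to show two closures that return identical gradients yet compare unequal:

```julia
# Hypothetical stand-in for a pullback (NOT Zygote's API): returns the value
# and a closure that captures a freshly allocated mutable cache.
function fake_pullback(x::Float32)
    cache = IdDict{Any,Any}()            # a new mutable object on every call
    back = function (dy)
        cache[:dy] = dy                  # touch the captured state
        return 2x * dy                   # d(x^2)/dx * dy
    end
    return x^2, back
end

y1, back1 = fake_pullback(3f0)
y2, back2 = fake_pullback(3f0)

@show y1 == y2                  # true: same primal value
@show back1 == back2            # false: `==` falls back to `===` on closures
@show back1(1f0) == back2(1f0)  # true: same gradient
```

In the MWE, the analogous check would compare the gradients obtained from res1[4](1f0) and res2[4](1f0). Since diffeq_layer only reads l̂ and decoder.diffeq, the loss there does not appear to depend on the arrays in ps, so differentiating with respect to the inputs in l̂ may be the more informative comparison.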

Note that I'm using EnsembleSerial() for the EnsembleProblem, so parallelism has nothing to do with this issue.

```
(issue#54) pkg> st
      Status `~/Documents/GOKU_experiments/issue#54/Project.toml`
  [41bf760c] DiffEqSensitivity v6.58.0
  [587475ba] Flux v0.12.6
  [5e00f16f] LatentDiffEq v0.2.5 ``
  [1dea7af3] OrdinaryDiffEq v5.64.0
  [9a3f8284] Random
```

gabrevaya (owner, issue author) commented:

I reported the more general MWE in SciML/SciMLSensitivity.jl#611.

