Replacing the EnsembleProblem with an equivalent simple for loop breaks the training #54

Open
gabrevaya opened this issue Sep 15, 2021 · 2 comments


@gabrevaya
Owner

If we change the differential equation solving inside diffeq_layer from the current EnsembleProblem to a simple for loop, whose forward pass gives exactly the same result, the whole training breaks. You can quickly try this by overriding the diffeq_layer function and running the pendulum example or tutorial:

using LatentDiffEq, Flux, OrdinaryDiffEq
import LatentDiffEq.Decoder
import LatentDiffEq.diffeq_layer
import LatentDiffEq.transform_after_diffeq

function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    # Adapt the problem to the given time span
    prob = remake(prob; tspan = (t[1],t[end]))
    sols = Array{Float32,2}[]

    # Solve one trajectory per column of ẑ₀ and θ̂
    for i in 1:size(ẑ₀,2)
        prob = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
        sol = solve(prob, solver; sensealg = sensealg, saveat = t, kwargs...)
        push!(sols, Array(sol))
    end

    # Stack into a state × trajectory × time array
    ẑ = Flux.stack(sols, 2)
    ẑ = transform_after_diffeq(ẑ, decoder.diffeq)
    return ẑ
end

# Identity by default
transform_after_diffeq(x, diffeq) = x

For reference, this is the original diffeq_layer function:

function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    # Function definition for ensemble problem
    prob_func(prob,i,repeat) = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])

    # Check if solve was successful, if not, return NaNs to avoid problems with dimensions
    output_func(sol, i) = sol.retcode == :Success ? (Array(sol), false) : (fill(NaN32,(size(ẑ₀, 1), length(t))), false)

    ## Adapt problem to given time span and create ensemble problem definition
    prob = remake(prob; tspan = (t[1], t[end]))
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)

    ## Solve
    ẑ = solve(ens_prob, solver, EnsembleThreads(); sensealg = sensealg, trajectories = size(θ̂, 2), saveat = t, kwargs...)

    # Transform the resulting output (mainly used for Kuramoto-like systems)
    ẑ = transform_after_diffeq(ẑ, decoder.diffeq)
    ẑ = permutedims(ẑ, [1,3,2])

    return ẑ
end

The forward output is exactly the same; however, training breaks down: after a few epochs, the loss starts increasing sharply (>> 1e6 after 50 epochs in the default pendulum example).
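To see why the forward outputs can match exactly, it helps to check the array layouts. The original version applies permutedims(ẑ, [1,3,2]) to the ensemble result, which implies the ensemble array is laid out as state × time × trajectory, while the loop version stacks per-trajectory state × time matrices along dimension 2. The minimal Base-Julia sketch below uses random matrices as hypothetical stand-ins for the ODE solutions and mimics Flux.stack(sols, 2) with reshape and cat, so it needs no packages.

```julia
# Hypothetical sizes: 2 state variables, 4 trajectories, 5 saved time points
state_dim, batch, T = 2, 4, 5

# One state × time matrix per trajectory (stand-ins for Array(sol))
sols = [rand(Float32, state_dim, T) for _ in 1:batch]

# EnsembleProblem-style layout: state × time × trajectory, then permuted
ens = cat(sols...; dims = 3)
ẑ_ensemble = permutedims(ens, (1, 3, 2))   # state × trajectory × time

# Loop-style layout, mimicking Flux.stack(sols, 2):
# add a singleton 2nd dimension to each matrix and concatenate along it
ẑ_loop = cat((reshape(s, state_dim, 1, T) for s in sols)...; dims = 2)

@show size(ẑ_ensemble)       # (2, 4, 5)
@show ẑ_ensemble == ẑ_loop   # true
```

Both arrays end up as state × trajectory × time with identical entries, which is consistent with the difference showing up only in training, i.e. in the backward pass.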

I'll try to put together an MWE and compare the gradients. This might hint at an issue with the sensitivity analysis, Zygote, or the differentiation primitives, which could also explain why the pendulum example in the Python implementation of GOKU nets converges faster and is more robust across random seeds, even with exactly the same architecture, hyperparameters and initializations.
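One way to narrow this down, not part of the original report, is to build the per-trajectory solutions without mutating an array via push!, for example with map, since Zygote's support for array mutation is limited. The helper solve_batch_map below is hypothetical, and the sketch assumes the package versions listed in the status output later in this thread (Flux v0.12.6, where Flux.stack(xs, dim) takes the dimension positionally, and DiffEqSensitivity v6.58). Whether this variant changes the gradients has not been checked here.

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux

# Hypothetical helper (not part of LatentDiffEq.jl): solves one trajectory per
# column of z0/θ and collects the solutions with `map` instead of `push!`.
function solve_batch_map(prob, solver, sensealg, z0, θ, t; kwargs...)
    prob = remake(prob; tspan = (t[1], t[end]))
    sols = map(1:size(z0, 2)) do i
        prob_i = remake(prob; u0 = z0[:, i], p = θ[:, i])
        Array(solve(prob_i, solver; sensealg = sensealg, saveat = t, kwargs...))
    end
    # state × trajectory × time, the same layout as the loop version
    return Flux.stack(sols, 2)
end

# Small pendulum problem mirroring the one used in the MWE further down
function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -10f0 / p[1] * sin(u[1])
end

prob = ODEProblem(pendulum!, Float32[1, 1], (0f0, 1f0), Float32[1])
z0 = rand(Float32, 2, 8)
θ = rand(Float32, 1, 8) .+ 0.5f0   # keep the pendulum length away from zero
t = range(0f0, 1f0; length = 21)

ẑ = solve_batch_map(prob, Tsit5(), ForwardDiffSensitivity(), z0, θ, t)
@show size(ẑ)   # (2, 8, 21)
```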

Or am I doing something wrong, from the DiffEqFlux point of view, in the simple for-loop version? @ChrisRackauckas, do you see anything wrong here at first glance?

@gabrevaya
Owner Author

Here is an MWE in the context of LatentDiffEq.jl. Later I'll create a more general, minimal MWE.

using LatentDiffEq
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Random
import LatentDiffEq.Decoder
import LatentDiffEq.diffeq_layer

Random.seed!(3)

struct Pendulum{P,S,T,K}
    prob::P
    solver::S
    sensealg::T
    kwargs::K
    function Pendulum(; kwargs...)
        # Parameters and initial conditions only
        # used to initialize the ODE problem
        u₀ = Float32[1.0, 1.0]
        p = Float32[1.]
        tspan = (0.f0, 1.f0)

        # Define differential equations
        function f!(du, u, p, t)
            x, y = u
            G = 10.0f0
            L = p[1]

            du[1] = y
            du[2] = -G/L*sin(x)
        end

        # Build ODE Problem
        prob = ODEProblem(f!, u₀, tspan, p)

        # Choose a solver and sensitivity algorithm
        solver = Tsit5()
        sensalg = ForwardDiffSensitivity()
        # sensalg = BacksolveAdjoint()

        # sensalg = InterpolatingAdjoint()
        # sensalg = QuadratureAdjoint()
        
        P = typeof(prob)
        S = typeof(solver)
        T = typeof(sensalg)
        K = typeof(kwargs)
        new{P,S,T,K}(prob, solver, sensalg, kwargs)
    end 
end

model_type = GOKU()
diffeq = Pendulum()
input_dim = 784

encoder_layers, decoder_layers = default_layers(model_type, input_dim, diffeq)
model = LatentDiffEqModel(model_type, encoder_layers, decoder_layers)
ps = Flux.params(model)
l̂ = (rand(Float32, 2, 64), rand(Float32, 1, 64))
t = range(0.f0, step=0.05, length=50)
loss(model, l̂, t) = sum(diffeq_layer(model.decoder, l̂, t) .- 1)

function evaluate(model, l̂, t)
    out = diffeq_layer(model.decoder, l̂, t)
    l_direct = loss(model, l̂, t)

    l, back = Flux.pullback(ps) do
        loss(model, l̂, t)
    end

    return out, l_direct, l, back
end


function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs
    
    # Function definition for ensemble problem
    prob_func(prob,i,repeat) = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
    
    # Check if solve was successful; if not, return NaNs to avoid dimension mismatches
    output_func(sol, i) = sol.retcode == :Success ? (Array(sol), false) : (fill(NaN32,(size(ẑ₀, 1), length(t))), false)
    
    ## Adapt problem to given time span and create ensemble problem definition
    prob = remake(prob; tspan = (t[1],t[end]))
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)
    
    ## Solve
    ẑ = solve(ens_prob, solver, EnsembleSerial(); sensealg = sensealg, trajectories = size(θ̂, 2), saveat = t, kwargs...)
    ẑ = permutedims(ẑ, [1,3,2])
    return ẑ
end

res1 = evaluate(model, l̂, t)

function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    prob = remake(prob; tspan = (t[1],t[end]))
    sols = Array{Float32,2}[]

    for i in 1:size(ẑ₀,2)
        prob = remake(prob, u0=ẑ₀[:,i], p = θ̂[:,i])
        sol = solve(prob, solver; sensealg = sensealg, saveat = t, kwargs...)
        push!(sols, Array(sol))
    end
    ẑ = Flux.stack(sols, 2)
    return ẑ
end

res2 = evaluate(model, l̂, t)

@show res1 .== res2

@show res1[2]
@show res1[3]

@show res2[2]
@show res2[3]

If ForwardDiffSensitivity() or BacksolveAdjoint() is used as sensalg (chosen inside the Pendulum struct), the results are:

res1 .== res2 = (true, true, true, false)
res1[2] = -6916.1006f0
res1[3] = -6916.1006f0
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0

If InterpolatingAdjoint() or QuadratureAdjoint() is used as sensalg:

res1 .== res2 = (true, true, false, false)
res1[2] = -6916.1006f0
res1[3] = -6916.099838162956
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0

So in both cases the back functions returned by Flux.pullback differ. In the second case, even the loss value returned by Flux.pullback does not match.
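One caveat when reading these results: res1 .== res2 compares the tuple elements pairwise with ==, and for the fourth element that means comparing two pullback closures. For closures, == falls back to object identity, so false there is expected even when both closures would return the same gradients. A more direct check is to call each pullback with a seed of 1 and compare the resulting gradients parameter by parameter. The toy sketch below does that with Zygote, which Flux.pullback uses underneath; the two loss functions are hypothetical stand-ins for the two diffeq_layer versions, not the models from this report.

```julia
using Zygote

# Hypothetical stand-ins for the two diffeq_layer versions: the same loss
# computed in one vectorized call and row by row.
W = rand(Float32, 3, 4)
x = rand(Float32, 4)
ps = Zygote.Params([W])

loss_vectorized() = sum(W * x)
loss_rowwise() = sum(map(i -> sum(W[i, :] .* x), 1:size(W, 1)))

l1, back1 = Zygote.pullback(loss_vectorized, ps)
l2, back2 = Zygote.pullback(loss_rowwise, ps)

# Comparing the pullback closures only checks object identity
@show back1 == back2   # false, regardless of the gradients

# Seed each pullback with 1 and compare the gradient of every parameter instead
gs1 = back1(1f0)
gs2 = back2(1f0)
@show l1 ≈ l2
@show all(isapprox(gs1[p], gs2[p]; rtol = 1f-5) for p in ps)
```

Applied to the MWE, the same idea would compare res1[4](1f0)[p] with res2[4](1f0)[p] for each p in ps.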

Note that I'm using EnsembleSerial() for the EnsembleProblem, so parallelism has nothing to do with this issue.

(issue#54) pkg> st
      Status `~/Documents/GOKU_experiments/issue#54/Project.toml`
  [41bf760c] DiffEqSensitivity v6.58.0
  [587475ba] Flux v0.12.6
  [5e00f16f] LatentDiffEq v0.2.5 `https://github.com/gabrevaya/LatentDiffEq.jl.git#master`
  [1dea7af3] OrdinaryDiffEq v5.64.0
  [9a3f8284] Random

@gabrevaya
Owner Author

I reported the more general MWE in SciML/SciMLSensitivity.jl#611.
