Runtime dispatch in train! #2113

Open
@Vilin97

Description

Package Version

v0.13.7

Julia Version

1.8.2

OS / Environment

Windows 11

Describe the bug

I followed the straight-line fitting example from the Flux docs and used JET to analyze train!. It reported 19 possible runtime dispatch errors.

Steps to Reproduce

using Flux, JET

# Ground-truth line to fit
actual(x) = 4x + 2
x_train, x_test = hcat(0:5...), hcat(6:10...)
y_train, y_test = actual.(x_train), actual.(x_test)

# Single dense layer and its mean-squared-error loss
predict = Dense(1 => 1)
loss_(x, y) = Flux.Losses.mse(predict(x), y)
opt = Descent()
data = [(x_train, y_train)]
parameters = Flux.params(predict)

Flux.train!(loss_, parameters, data, opt)  # [edit: qualify train!]
@report_opt Flux.train!(loss_, parameters, data, opt)
# ═════ 19 possible errors found ═════
# runtime dispatch detected: isequal(%1::Any, v::Task)::Bool

Expected Results

I expected JET to report no errors.

Observed Results

JET reported 19 possible runtime dispatch errors.

Relevant log output

julia> @report_opt train!(loss_, parameters, data, opt)
═════ 19 possible errors found ═════
┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:125 Flux.Optimise.:(var"#train!#36")(#38, #self#, loss, ps, data, opt)      
│┌ @ logging.jl:376 logger = Base.CoreLogging.current_logger_for_env(std_level, group, _module)
││┌ @ logging.jl:499 Base.CoreLogging.env_override_minlevel(group, _module)
│││┌ @ logging.jl:565 Base.moduleroot(_module)
││││┌ @ reflection.jl:45 Base.is_root_module(m)
│││││┌ @ lock.jl:221 lock(temp)
││││││┌ @ lock.jl:103 slowlock(rl)
│││││││┌ @ lock.jl:112 wait(c)
││││││││┌ @ condition.jl:126 Base.list_deletefirst!(ct.queue, ct)
│││││││││┌ @ linked_list.jl:145 isequal(h.value, val)
││││││││││┌ @ gcutils.jl:4 isequal(%1, v)
│││││││││││ runtime dispatch detected: isequal(%1::Any, v::Task)::Bool    
││││││││││└────────────────
││││││││┌ @ condition.jl:126 Base.list_deletefirst!(%45, %39)
│││││││││ runtime dispatch detected: Base.list_deletefirst!(%45::Any, %39::Task)::Any
││││││││└────────────────────
│││││┌ @ lock.jl:225 unlock(temp)
││││││┌ @ lock.jl:133 _unlock(rl)
│││││││┌ @ lock.jl:139 notifywaiters(rl)
││││││││┌ @ lock.jl:143  = notify(cond_wait)
│││││││││┌ @ condition.jl:142 #self#(c, Base.nothing)
││││││││││┌ @ condition.jl:142 Base.:(var"#notify#586")(true, false, #self#, c, arg)
│││││││││││┌ @ condition.jl:142 notify(c, arg, all, error)
││││││││││││┌ @ condition.jl:148 Core.kwfunc(schedule)(NamedTuple{(:error,)}(tuple(error)), schedule, t, arg)
│││││││││││││┌ @ task.jl:789 Base.:(var"#schedule#613")(error, _3, t, arg)
││││││││││││││┌ @ task.jl:793 %10(%11, t)
│││││││││││││││ runtime dispatch detected: %10::typeof(Base.list_deletefirst!)(%11::Any, t::Task)::Any
││││││││││││││└───────────────
│┌ @ logging.jl:364 Base.CoreLogging.logging_error(logger, level, _module, group, id, file, line, err, true)
││┌ @ logging.jl:463 (%51)
│││ runtime dispatch detected: ::NamedTuple{(:exception,)}(%51::Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}})::NamedTuple{(:exception,), _A} where _A<:Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}}
││└──────────────────
││┌ @ logging.jl:463 handle_message##kw(%52, Base.CoreLogging.handle_message, %37, Base.CoreLogging.Error, %44, %38, :logevent_error, %39, %40, %41)
│││ runtime dispatch detected: handle_message##kw(%52::NamedTuple{(:exception,), _A} where _A<:Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}}, Base.CoreLogging.handle_message, %37::Base.CoreLogging.AbstractLogger, Base.CoreLogging.Error, %44::Union{LazyString, String}, %38::Module, :logevent_error, %39::Base.UUID, %40::String, %41::Int64)::Any     
││└──────────────────
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:131 Flux.Optimise.withgradient(#37, ps)
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:132 pullback(tuple(f), args...)
│││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:384 Zygote._pullback(cx, f)
││││┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:132 Zygote._pullback(ctx, Core._apply_iterate, iterate, Zygote._pullback(ctx, Zygote.literal_getfield, f, Val{:loss}())[1], Zygote._pullback(ctx, Flux.Optimise.batchmemaybe, Zygote._pullback(ctx, Zygote.literal_getfield, f, Val{:d}())[1])[1])
│││││┌ @ C:\Users\Math User\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 ZygoteRules.adjoint(tuple(__context__, 550, 551, f), args...)       
││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:203 Core._apply(tuple(Zygote._pullback, tuple(__context__, f)), args...)
│││││││┌ @ boot.jl:816 Core._apply_iterate(tuple(Main.Base.iterate), x...)
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, Zygote.literal_getproperty, Flux.Losses, Val{:mse}())
│││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\literal_getproperty.jl:83 Zygote._pullback(cx, Zygote.getproperty, x, :mse)
││││││││││┌ @ Base.jl:31 Zygote._pullback(ctx, Base.getfield, Base.getfield(args, 1), Base.getfield(args, 2))
│││││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:244 Zygote.Val(field_name)
││││││││││││┌ @ essentials.jl:714 %1()
│││││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()::Val
││││││││││││└─────────────────────
│││││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:244 Zygote._pullback(cx, Zygote.literal_getfield, x, %1)
││││││││││││ runtime dispatch detected: Zygote._pullback(cx::Zygote.Context{true}, Zygote.literal_getfield, x::Module, %1::Val)::Tuple{Any, Zygote.var"#2077#back#218"}
│││││││││││└───────────────────────────────────────────────────────────   
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, %12, %1)
│││││││││ runtime dispatch detected: Zygote._pullback(ctx::Zygote.Context{true}, %12::Any, %1::Matrix{Int64})::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %21[1]
│││││││││ runtime dispatch detected: (%21::Any)[1]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %21[2]
│││││││││ runtime dispatch detected: (%21::Any)[2]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, %5, %22, %2)
│││││││││ runtime dispatch detected: Zygote._pullback(ctx::Zygote.Context{true}, %5::Any, %22::Any, %2::Matrix{Int64})::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %24[1]
│││││││││ runtime dispatch detected: (%24::Any)[1]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %24[2]
│││││││││ runtime dispatch detected: (%24::Any)[2]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 grad = back(Zygote.sensitivity(y))
│││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:390 Zygote.Grads(getfield(#self#, :cx).cache, getfield(#self#, :ps))
││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:281 convert(, grads)
│││││ runtime dispatch detected: convert(::IdDict{Any, Any}, grads::Nothing)
││││└──────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 Zygote.sensitivity(%5)
│││ runtime dispatch detected: Zygote.sensitivity(%5::Any)::Any
││└────────────────────────────────────────────────────────────────────── 
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 %10(%11)
│││ runtime dispatch detected: %10::Zygote.var"#99#100"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, _A, Zygote.Context{true}} where _A(%11::Any)::Zygote.Grads
││└────────────────────────────────────────────────────────────────────── 
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:135 (%13)
│││ runtime dispatch detected: ::NamedTuple{(:val, :grad)}(%13::Tuple{Any, Zygote.Grads})::NamedTuple{(:val, :grad), _A} where _A<:Tuple{Any, Zygote.Grads}
││└────────────────────────────────────────────────────────────────────── 
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:135 string("Loss is ", l, " on data item ", i, ", stopping training")      
││┌ @ strings/io.jl:185 Base.print_to_string(xs...)
│││┌ @ strings/io.jl:144 print(%83, %86)
││││ runtime dispatch detected: print(%83::IOBuffer, %86::Any)::Any       
│││└─────────────────────
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:137 update!(opt, ps, gs)
││┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:24 update!(opt, %32, %65)
│││ runtime dispatch detected: update!(opt::Descent, %32::Any, %65::Any)::Any
││└──────────────────────────────────────────────────────────────────────────
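
A possible reading of the entries at fokker_planck_probability_flow.jl:170 (for example `Zygote._pullback(ctx, %12::Any, %1::Matrix{Int64})`): they may come from loss_ referring to the non-const global predict, rather than from train! itself. This is my assumption, not something the log confirms. The sketch below uses hypothetical stand-ins (`scale`, `loss_global`, `loss_arg`) and only Base, with no Flux or JET, to show how a non-const global by itself makes the inferred return type Any.

```julia
# Hypothetical stand-ins for `predict` and `loss_` from the steps above.
scale = 2.0                              # non-const global, like `predict`
loss_global(x) = sum(abs2, scale .* x)   # reads the global on every call

# Same computation, but the value is passed in as an argument.
loss_arg(s, x) = sum(abs2, s .* x)

x = [1.0, 2.0, 3.0]

# Ask the compiler what return type it infers for each method.
rt_global = only(Base.return_types(loss_global, (Vector{Float64},)))
rt_arg    = only(Base.return_types(loss_arg, (Float64, Vector{Float64})))

println("non-const global: ", rt_global)
println("passed as argument: ", rt_arg)
```

If this reading is right, declaring predict as const or passing the model into the loss as an argument should remove those particular entries. It would not affect the ones that originate in Base logging or inside Zygote.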
