update comments, fix BFGS. Looks like API is broken SciML/Optimizatio…
vpuri3 committed Oct 3, 2024
1 parent 4a9184a commit bf0d426
Showing 2 changed files with 15 additions and 14 deletions.
22 changes: 12 additions & 10 deletions src/train/backend.jl
@@ -117,28 +117,30 @@ function train_loop!(
 )
     @unpack __loader = loaders
     @unpack lossfun, opt_args, opt_iter = trainer
-    @unpack io, verbose = trainer
+    @unpack io, verbose, device = trainer
 
     # batch = first(__loader)
-    batch = if __loader isa CuIterator
-        # Adapt.adapt(__loader, __loader.batches.data)
-        __loader.batches.data |> cu
+    batch = if __loader isa MLDataDevices.DeviceIterator
+        __loader.iterator.data |> device
     else
         __loader.data
     end
 
+    # https://github.com/SciML/Optimization.jl/issues/839
+
     ### TODO: using old st in BFGS
     function optloss(optx, optp)
-        lossfun(state.NN, optx, state.st, batch)
+        lossfun(state.NN, optx, state.st, batch)[1] # l, st, stats
     end
 
-    function optcb(optx, l, st, stats)
+    function optcb(optx, l) # optx, l, st, stats
         evaluate(trainer, state, loaders)
-        state = TrainState(state.NN, optx.u, Lux.trainmode(st), state.opt_st)
+        # state = TrainState(state.NN, optx.u, Lux.trainmode(st), state.opt_st)
+        @set! state.p = optx.u
 
-        if !isempty(stats) & verbose
-            println(io, stats)
-        end
+        # if !isempty(stats) & verbose
+        #     println(io, stats)
+        # end
 
         opt_iter.epoch[] += 1
         opt_iter.epoch_dt[] = time() - opt_iter.epoch_time[] - opt_iter.start_time[]
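For context on the loader change above: wrapping a DataLoader with an MLDataDevices device yields an iterator whose batches land on that device, and the underlying loader stays reachable through the `.iterator` field, which is what the full-batch BFGS path reaches into. The snippet below is a hedged sketch, not code from this repository; `xs`, `ys`, and `loader` are placeholder names.

```julia
# Minimal sketch (not repository code) of iterating a DataLoader through an
# MLDataDevices device. Assumes MLDataDevices.jl and MLUtils.jl are installed.
using MLDataDevices, MLUtils

device = gpu_device()   # falls back to a CPU device if no functional GPU exists
xs, ys = rand(Float32, 2, 64), rand(Float32, 1, 64)
loader = DataLoader((xs, ys); batchsize = 16)

for (x, y) in device(loader)
    # With a GPU device, `device(loader)` is an MLDataDevices.DeviceIterator and
    # x, y arrive already on the device. The wrapped loader (and hence the full
    # dataset) remains reachable as `.iterator.data`, which is what the diff
    # above uses to build one full-data batch for BFGS.
end
```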
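The remaining edits work around the Optimization.jl callback change tracked in SciML/Optimization.jl#839: the objective handed to `OptimizationFunction` must return a single scalar, and the callback now receives only the optimization state and the loss. The following is a hedged, self-contained sketch of that pattern, not this repository's code; the model, data, and names are placeholders.

```julia
# Illustrative sketch of the two-argument callback / scalar-objective pattern
# the commit adapts to. Assumes Optimization.jl, OptimizationOptimJL.jl,
# Zygote.jl, Lux.jl, and ComponentArrays.jl; every name below is a placeholder.
using Optimization, OptimizationOptimJL, Zygote
using Lux, ComponentArrays, Random

rng    = Random.default_rng()
model  = Dense(2 => 1)
ps, st = Lux.setup(rng, model)
u0     = ComponentArray(ps)
x, y   = rand(rng, Float32, 2, 32), rand(rng, Float32, 1, 32)

# The objective must return one scalar -- hence the `[1]` in the diff, which
# drops the (st, stats) part of the loss function's return value.
function optloss(optx, optp)
    yhat, _ = model(x, optx, st)
    sum(abs2, yhat .- y) / size(y, 2)
end

# The callback now takes (state, loss) instead of (state, loss, st, stats).
function optcb(optx, l)
    println("loss = ", l)
    return false               # returning true would stop the optimizer
end

optfun  = OptimizationFunction(optloss, AutoZygote())
optprob = OptimizationProblem(optfun, u0)
sol = solve(optprob, BFGS(); callback = optcb, maxiters = 20)
```

Returning `false` from the callback keeps the solver running; the diff's `optcb` follows the same two-argument convention while updating its own `TrainState` via `@set!`.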
7 changes: 3 additions & 4 deletions src/train/trainer.jl
@@ -14,10 +14,6 @@ abstract type AbstractTrainState end
     opt_st
 end
 
-# rm once https://github.com/FluxML/Optimisers.jl/pull/180
-# is merged
-Adapt.@adapt_structure Optimisers.Leaf
-
 function Adapt.adapt_structure(to, state::TrainState)
     p = Adapt.adapt_structure(to, state.p )
     st = Adapt.adapt_structure(to, state.st)
@@ -26,6 +22,9 @@ function Adapt.adapt_structure(to, state::TrainState)
     TrainState(state.NN, p, st, opt_st)
 end
 
+# rm once https://github.com/FluxML/Optimisers.jl/pull/180 is merged
+Adapt.@adapt_structure Optimisers.Leaf
+
 #===============================================================#
 abstract type AbstractTrainer end
 
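For reference, the two `Adapt` definitions above are what allow a whole `TrainState` to be moved to the GPU in one call: the custom `adapt_structure` method recurses into `p`, `st`, and `opt_st`, and `Adapt.@adapt_structure Optimisers.Leaf` lets the optimiser state tree follow. A rough sketch of that usage, assuming a functional CUDA.jl setup; the network and hyperparameters below are made up.

```julia
# Illustrative sketch (not repository code): adapt an entire TrainState to the
# GPU. Relies on the TrainState adapt_structure method above plus
# `Adapt.@adapt_structure Optimisers.Leaf` so optimiser state moves too.
using Adapt, CUDA, Lux, Optimisers, Random

NN     = Dense(4 => 4)                                 # placeholder network
p, st  = Lux.setup(Random.default_rng(), NN)
opt_st = Optimisers.setup(Optimisers.Adam(1f-3), p)

state     = TrainState(NN, p, st, opt_st)
gpu_state = Adapt.adapt(CuArray, state)   # p, st, opt_st now hold CuArrays
```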
