Add explicit train!, unify update!, and auto-translate the two Adams #2082

Merged · 18 commits · Nov 20, 2022

Changes from 1 commit
remove 2 vs 3 argument comment from docstring
mcabbott committed Nov 20, 2022
commit db2a9b90584c8fe5b4b511e2478ea350f4b9b2bf
src/train.jl: 14 changes (6 additions, 8 deletions)
@@ -16,8 +16,10 @@ using Zygote: Zygote, Params
 
 This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
 It differs from `Optimisers.setup` in that it:
-* has one extra check for mutability
+* has one extra check for mutability (since Flux expects to mutate the model in-place,
+  while Optimisers.jl is designed to return an updated model)
 * has methods which accept Flux's old optimisers, and convert them.
+  (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.)
 
 # Example
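For illustration only, a minimal sketch of the two behaviours this docstring describes; the `Chain`/`Dense` model and the 0.01 learning rate are invented here, not taken from the PR:

```julia
using Flux
import Optimisers

model = Chain(Dense(2 => 4, relu), Dense(4 => 1))

# Flux.setup walks the model and builds a matching tree of optimiser
# state; it also checks that the model holds mutable arrays, since
# train! will update them in-place.
opt_state = Flux.setup(Optimisers.Adam(0.01), model)

# The old-style rule is a distinct type, which setup converts to the
# equivalent Optimisers.jl rule before building the state:
opt_state = Flux.setup(Flux.Optimise.Adam(0.01), model)
```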
@@ -80,16 +82,12 @@ It adds only a few features to the loop above:
 
 * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
 
-Note that the built-in loss functions accept 3 arguments, allowing for instance
-`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
-
 !!! note
     This method has significant changes from the one in Flux ≤ 0.13:
     * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
       (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
-    * Instead of `loss` being a function which typically accepts two arguments
-      (the input `x` and expected output `y` from each element of `data`)
-      now it should typically accept three, the first of which is the `model` itself.
+    * Instead of `loss` being a function which accepts only the data,
+      now it must also accept the `model` itself, as the first argument.
    * `data` must iterate tuples, otherwise you get an error.
      (Previously non-tuple types were not splatted into the loss.
      Pass in `((d,) for d in data)` to simulate this.)
@@ -130,4 +128,4 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state
 end
 
-end # module
+end # module Train