A small package for applying early stopping criteria to loss-generating iterative algorithms, with a view to training and optimizing machine learning models.
This package forms the basis of IterationControl.jl, a package for externally controlling iterative algorithms.
Includes the stopping criteria surveyed in Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.
```julia
using Pkg
Pkg.add("EarlyStopping")
```
The `EarlyStopper` objects defined in this package consume a sequence of numbers called losses, generated by some external algorithm - generally the training loss or out-of-sample loss of some iterative statistical model - and decide when those losses have dropped sufficiently to warrant terminating the algorithm. A number of commonly applied stopping criteria, listed under Criteria below, are provided out-of-the-box.
Here's an example of using an `EarlyStopper` object to check against two of these criteria (either triggering the stop):

```julia
using EarlyStopping

stopper = EarlyStopper(Patience(2), InvalidValue()) # multiple criteria

done!(stopper, 0.123) # false
done!(stopper, 0.234) # false
done!(stopper, 0.345) # true
```

```julia-repl
julia> message(stopper)
"Early stop triggered by Patience(2) stopping criterion. "
```
One may force an `EarlyStopper` to report its evolving state:

```julia
losses = [10.0, 11.0, 10.0, 11.0, 12.0, 10.0];
stopper = EarlyStopper(Patience(2), verbosity=1);

for loss in losses
    done!(stopper, loss) && break
end
```

```
[ Info: loss: 10.0 state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0 state: (loss = 11.0, n_increases = 1)
[ Info: loss: 10.0 state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0 state: (loss = 11.0, n_increases = 1)
[ Info: loss: 12.0 state: (loss = 12.0, n_increases = 2)
```
The "object-oriented" interface demonstrated here is not code-optimized but will suffice for the majority of use-cases. For performant code, use the functional interface described under Implementing new criteria below.
To list all stopping criteria, do `subtypes(StoppingCriterion)`. Each subtype `T` has a detailed doc-string, queried with `?T` at the REPL. Here is a short summary:
criterion | description | notation in Prechelt |
---|---|---|
`Never()` | Never stop | |
`InvalidValue()` | Stop when `NaN`, `Inf` or `-Inf` encountered | |
`TimeLimit(t=0.5)` | Stop after `t` hours | |
`NumberLimit(n=100)` | Stop after `n` loss updates (excl. "training losses") | |
`NumberSinceBest(n=6)` | Stop when `n` loss updates have passed since the lowest loss (excl. "training losses") | |
`Threshold(value=0.0)` | Stop when `loss < value` | |
`GL(alpha=2.0)` | Stop after "Generalization Loss" exceeds `alpha` | `GL_α` |
`PQ(alpha=0.75, k=5)` | Stop after "Progress-modified GL" exceeds `alpha` | `PQ_α` |
`Patience(n=5)` | Stop after `n` consecutive loss increases | `UP_s` |
`Disjunction(c...)` | Stop when any of the criteria `c` apply | |
`Warmup(c; n=1)` | Wait for `n` loss updates before checking criteria `c` | |
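Passing several criteria to `EarlyStopper`, as in the first example above, stops as soon as any one of them triggers, mirroring `Disjunction`. Here is a small illustration (the commented results assume the default constructor behaviour listed in the table):

```julia
using EarlyStopping

# Stop on whichever happens first: two consecutive loss increases (Patience)
# or five loss updates in total (NumberLimit):
stopper = EarlyStopper(Patience(2), NumberLimit(n=5))

done!(stopper, 9.0)  # false
done!(stopper, 8.0)  # false
done!(stopper, 8.5)  # false (one increase)
done!(stopper, 9.0)  # true  (two consecutive increases)
```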
For criteria tracking both an "out-of-sample" loss and a "training" loss (eg, stopping criterion of type `PQ`), specify `training=true` if the update is for training, as in

```julia
done!(stopper, 0.123, training=true)
```
In these cases, the out-of-sample update must always come after the corresponding training update. Multiple training updates may precede the out-of-sample update, as in the following example:
```julia
criterion = PQ(alpha=2.0, k=2)
needs_training_losses(criterion) # true

stopper = EarlyStopper(criterion)

done!(stopper, 9.5, training=true)  # false
done!(stopper, 9.3, training=true)  # false
done!(stopper, 10.0)                # false

done!(stopper, 9.3, training=true)  # false
done!(stopper, 9.1, training=true)  # false
done!(stopper, 8.9, training=true)  # false
done!(stopper, 8.0)                 # false

done!(stopper, 8.3, training=true)  # false
done!(stopper, 8.4, training=true)  # false
done!(stopper, 9.0)                 # true
```
Important. If there is no distinction between in-sample and out-of-sample losses, then any criterion can be applied, and in that case `training=true` is never specified (regardless of the actual interpretation of the losses being tracked).
To determine the stopping time for an iterator `losses`, use `stopping_time(criterion, losses)`. This is useful for debugging new criteria (see below). If the iterator terminates without a stop, `0` is returned.
```julia-repl
julia> stopping_time(InvalidValue(), [10.0, 3.0, Inf, 4.0])
3

julia> stopping_time(Patience(3), [10.0, 3.0, 4.0, 5.0], verbosity=1)
[ Info: loss updates: 1
[ Info: state: (loss = 10.0, n_increases = 0)
[ Info: loss updates: 2
[ Info: state: (loss = 3.0, n_increases = 0)
[ Info: loss updates: 3
[ Info: state: (loss = 4.0, n_increases = 1)
[ Info: loss updates: 4
[ Info: state: (loss = 5.0, n_increases = 2)
0
```
If the losses include both training and out-of-sample losses as described above, pass an extra `Bool` vector marking the training losses with `true`, as in

```julia
stopping_time(PQ(),
              [0.123, 0.321, 0.52, 0.55, 0.56, 0.58],
              [true, true, false, true, true, false])
```
To implement a new stopping criterion, one must:

- Define a new `struct` for the criterion, which must subtype `StoppingCriterion`.
- Overload the methods `update` and `done` for the new type.
```julia
struct NewCriteria <: StoppingCriterion
    # Put relevant fields here
end

# Provide a default constructor with all keyword arguments
NewCriteria(; kwargs...) = ...

# Return the initial state of the NewCriteria after
# receiving an out-of-sample loss
update(c::NewCriteria, loss, ::Nothing) = ...

# Return an updated state for NewCriteria given a `loss`
# and the current `state`
update(c::NewCriteria, loss, state) = ...

# Return true if NewCriteria should stop given `state`.
# Always return false if `state === nothing`
done(c::NewCriteria, state) = state === nothing ? false : ...
```
Optionally, one may also:

- Overload the final message with `message`.
- Handle training losses by overloading `update_training` and the trait `needs_training_losses`.
```julia
# Final message when NewCriteria triggers a stop
message(c::NewCriteria, state) = ...

# Methods for initializing/updating the state given a training loss
update_training(c::NewCriteria, loss, ::Nothing) = ...
update_training(c::NewCriteria, loss, state) = ...
```
Wrappers. If your criterion wraps another criterion (as `Warmup` does), then the wrapped criterion must be stored in a field named `criterion`.
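For illustration only (this wrapper is not part of the package), a hypothetical wrapper that logs each loss before delegating to the wrapped criterion might look like:

```julia
using EarlyStopping
import EarlyStopping: update, done

# Hypothetical wrapper, for illustration only: logs each loss and otherwise
# delegates to the wrapped criterion, stored in the required `criterion` field.
struct Logged{C<:StoppingCriterion} <: StoppingCriterion
    criterion::C
end

function update(logged::Logged, loss, state)
    @info "loss: $loss"
    return update(logged.criterion, loss, state)
end

done(logged::Logged, state) = done(logged.criterion, state)
```

A full-featured wrapper would also forward `update_training`, `message` and the `needs_training_losses` trait, which are described below.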
We now demonstrate the implementation of a new criterion with a simplified version of the code for `Patience`:
```julia
using EarlyStopping

struct Patience <: StoppingCriterion
    n::Int
end

Patience(; n=5) = Patience(n)
```
All information to be "remembered" must passed around in an object
called state
below, which is the return value of update
(and
update_training
). The update
function has two methods:
- Initialization:
update(c::NewCriteria, loss, ::Nothing)
- Subsequent Loss Updates:
update(c::NewCriteria, loss, state)
Where state
is the return of the previous call to update
or update_training
.
Notice, that state === nothing
indicates an uninitialized criteria.
```julia
import EarlyStopping: update, done

function update(criterion::Patience, loss, ::Nothing)
    return (loss=loss, n_increases=0) # state
end

function update(criterion::Patience, loss, state)
    old_loss, n = state
    if loss > old_loss
        n += 1
    else
        n = 0
    end
    return (loss=loss, n_increases=n) # state
end
```
The `done` method returns `true` or `false` depending on the `state`, but always returns `false` if `state === nothing`:

```julia
done(criterion::Patience, state) =
    state === nothing ? false : state.n_increases == criterion.n
```
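As a quick sanity check, the hand-rolled criterion can be exercised with `stopping_time` (the commented value assumes the simplified implementation above):

```julia
# Increases occur at updates 3, 4 and 5, so Patience(3) stops at update 5:
stopping_time(Patience(3), [10.0, 9.0, 9.5, 9.7, 9.9])  # 5
```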
The final message of an `EarlyStopper` is generated by a `message` method for `StoppingCriterion`. Here is the fallback (which does not use `state`):

```julia
EarlyStopping.message(criterion::StoppingCriterion, state) =
    "Early stop triggered by $criterion stopping criterion. "
```
The optional `update_training` methods (two for each criterion) have the same signature as the `update` methods above. Refer to the `PQ` code for an example.
If a stopping criterion requires one or more `update_training` calls per `update` call to work, you should overload the trait `needs_training_losses` for that type, as in this example from the source code:

```julia
EarlyStopping.needs_training_losses(::Type{<:PQ}) = true
```
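To make the pattern concrete, here is a hypothetical criterion (not provided by the package) that stops when the out-of-sample loss exceeds the most recent training loss by more than some `gap`, and therefore needs training losses:

```julia
using EarlyStopping
import EarlyStopping: update, update_training, done, needs_training_losses

# Hypothetical criterion, for illustration only.
struct OverfitGap <: StoppingCriterion
    gap::Float64
end
OverfitGap(; gap=1.0) = OverfitGap(gap)

# Training updates just record the latest training loss:
update_training(::OverfitGap, loss, ::Nothing) = (training=loss, overfit=false)
update_training(::OverfitGap, loss, state) = (training=loss, overfit=state.overfit)

# An out-of-sample loss arriving before any training loss becomes the reference:
update(::OverfitGap, loss, ::Nothing) = (training=loss, overfit=false)

# Out-of-sample updates compare against the recorded training loss:
update(c::OverfitGap, loss, state) =
    (training=state.training, overfit=loss - state.training > c.gap)

done(::OverfitGap, state) = state === nothing ? false : state.overfit

needs_training_losses(::Type{<:OverfitGap}) = true
```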
The following are provided to facilitate testing of new criteria:

- `stopping_time`: returns the stopping time for an iterator `losses` using `criterion`.
- `@test_criteria NewCriteria()`: runs a suite of unit tests against the provided `StoppingCriterion`. This macro is only part of the test suite and is not part of the API.