Add step learning rate scheduler (#2423)

* Add step learning rate scheduler (#1198)

* Add step learning rate scheduler

* Add a test utility function for comparing outputs of two schedulers

* fixup! Add step learning rate scheduler (#1198)

* Assert instead of returning `Result`

* Warn about atypical values of initial LR and gamma
towerpark authored Nov 18, 2024
1 parent a78597d commit 6d105ea
Showing 3 changed files with 230 additions and 12 deletions.
29 changes: 17 additions & 12 deletions crates/burn-core/src/lr_scheduler/base.rs
@@ -23,15 +23,16 @@ pub(super) mod test_utils {
     use super::*;
     use crate::TestBackend;
 
+    // A small tolerance for learning rate comparisons. Depending on how learning rates are
+    // computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is
+    // used here.
+    const LOOSE_EPSILON: LearningRate = 1e-10;
+
     pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)
     where
         I: IntoIterator<Item = LearningRate>,
         S: LrScheduler,
     {
-        // Depending on how learning rates are computed by the scheduler, floating-point arithmetic
-        // error might exceed f64::EPSILON, so we use a larger epsilon here.
-        const LOOSE_EPSILON: f64 = 1e-10;
-
         expected_lrs
             .into_iter()
             .enumerate()
@@ -61,15 +62,19 @@ pub(super) mod test_utils {
         scheduler = scheduler.load_record::<TestBackend>(rec);
 
         // Validate that the scheduler resumes from where it left off.
-        (save_at_step..2 * save_at_step).for_each(|i| {
-            let expected = truth.step();
-            let lr = scheduler.step();
+        // The two schedulers run with the exact same settings and code,
+        // so the difference, if any, should be small enough to fit in f64::EPSILON.
+        compare_steps(&mut scheduler, &mut truth, save_at_step);
+    }
+
+    // Check if two schedulers produce the same learning rate sequences over the specified number of
+    // steps.
+    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {
+        (0..num_steps).for_each(|i| {
+            let lr_a = a.step();
+            let lr_b = b.step();
             assert!(
-                (lr - expected).abs() < f64::EPSILON,
-                "Scheduled learning rate {lr} is not approximately equal to the expected value \
-                 {expected} at step {i}",
+                (lr_a - lr_b).abs() < LOOSE_EPSILON,
+                "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \
+                 sequences are not approximately equal",
             );
         });
     }
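For illustration only (not part of this commit): the looser `LOOSE_EPSILON` tolerance exists because two mathematically equivalent ways of computing a learning rate can disagree by more than `f64::EPSILON` once many floating-point operations are involved. A minimal, hypothetical sketch of that effect:

```rust
fn main() {
    let init_lr: f64 = 1.0;
    let gamma: f64 = 0.999_999;

    // Closed form: only a handful of rounding steps.
    let closed_form = init_lr * gamma.powi(1_000_000);

    // Iterative form: a rounding error can accumulate at every multiplication.
    let mut iterative = init_lr;
    for _ in 0..1_000_000 {
        iterative *= gamma;
    }

    // The gap is typically far larger than f64::EPSILON (~2.2e-16),
    // yet still comfortably below the 1e-10 tolerance above.
    let diff = (closed_form - iterative).abs();
    println!("difference = {diff:e}");
    assert!(diff < 1e-10);
}
```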
3 changes: 3 additions & 0 deletions crates/burn-core/src/lr_scheduler/mod.rs
@@ -13,6 +13,9 @@ pub mod exponential;
 /// Cosine learning rate scheduler
 pub mod cosine;
 
+/// Step learning rate scheduler
+pub mod step;
+
 mod base;
 
 pub use base::*;
210 changes: 210 additions & 0 deletions crates/burn-core/src/lr_scheduler/step.rs
@@ -0,0 +1,210 @@
use burn_tensor::backend::Backend;

use crate as burn;

use super::LrScheduler;
use crate::{config::Config, LearningRate};

/// The configuration for creating a [step learning rate scheduler](StepLrScheduler).
///
/// This scheduler starts at `initial_lr` and returns that value for the first `step_size`
/// iterations. It then multiplies the learning rate by `gamma` and repeats the process every
/// `step_size` iterations.
///
/// Gamma values outside the range (0.0, 1.0) and non-positive initial learning rates are
/// accepted, but a warning is logged for such values in case they are typos.
///
/// ## Notes
///
/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
/// `i32::MAX + 1` times.
#[derive(Config)]
pub struct StepLrSchedulerConfig {
// The learning rate at the initial step.
initial_lr: LearningRate,
// The number of iterations over which the learning rate remains unchanged before the next
// update.
step_size: usize,
/// The factor by which the learning rate is multiplied with each update. Default: 0.1.
#[config(default = 0.1)]
gamma: f64,
}

impl StepLrSchedulerConfig {
/// Initializes a [step learning rate scheduler](StepLrScheduler).
///
/// # Panics
///
/// Panics if `step_size` is 0.
pub fn init(&self) -> StepLrScheduler {
assert!(self.step_size > 0, "Step size must be greater than 0.");

// Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful
// in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
if self.initial_lr <= 0.0 {
log::warn!(
"Initial learning rate value of {} is not a positive number. Ignore this warning \
if it is intended.",
self.initial_lr
);
}
if self.gamma <= 0.0 || self.gamma >= 1.0 {
log::warn!(
"Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
intended.",
self.gamma
);
}

StepLrScheduler {
init_lr: self.initial_lr,
step_size: self.step_size,
gamma: self.gamma,
iter_idx: -1,
}
}
}
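For context, a minimal usage sketch of the configuration shown above (illustrative only, not part of this commit; the `burn::lr_scheduler` path assumes the usual re-export of this module from the `burn` crate):

```rust
use burn::lr_scheduler::{step::StepLrSchedulerConfig, LrScheduler};

fn example() {
    // Start at 0.5 and multiply the learning rate by 0.1 after every 3 steps.
    let mut scheduler = StepLrSchedulerConfig::new(0.5, 3).with_gamma(0.1).init();

    // Yields 0.5, 0.5, 0.5, 0.05, 0.05, 0.05, ...
    for _ in 0..6 {
        println!("{}", scheduler.step());
    }
}
```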

/// Step learning rate scheduler.
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
init_lr: LearningRate,
step_size: usize,
gamma: f64,
// The index of the current iteration.
// `i32` is used to avoid truncating the exponent when taking powers of `gamma`.
iter_idx: i32,
}

impl LrScheduler for StepLrScheduler {
type Record<B: Backend> = i32;

fn step(&mut self) -> LearningRate {
self.iter_idx = self
.iter_idx
.checked_add(1)
.expect("`.step()` should be called no more than `i32::MAX + 1` times");
// The casts below cause no truncation: `iter_idx` is non-negative here, and the quotient never
// exceeds `iter_idx`, so it fits back into `i32`.
self.init_lr
* self
.gamma
.powi((self.iter_idx as usize / self.step_size) as i32)
}

fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.iter_idx
}

fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.iter_idx = record;
self
}
}
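Concretely, `step` evaluates the closed form `initial_lr * gamma^(iter_idx / step_size)` with integer division: for `initial_lr = 0.5`, `step_size = 3`, and `gamma = 0.1`, iterations 0-2 yield 0.5 * 0.1^0 = 0.5, iterations 3-5 yield 0.5 * 0.1^1 = 0.05, and iterations 6-8 yield 0.5 * 0.1^2 = 0.005, which is exactly the sequence asserted in `test_lr_decreasing` below.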

#[cfg(test)]
mod tests {
use super::super::test_utils;
use super::*;
use crate::TestBackend;

// Warning logs for initial LR and gamma are not tested because there seems to be no
// straightforward way to do it.
//
// Creating a mock logger that collects logs into a `String` for later examination looks like a
// possible solution, but unit tests run in parallel within the same process, where the single
// logger would be shared by multiple tests, so logs from different tests would be mixed up with
// no easy way to separate them.
// Using "--test-threads=1" could prevent the mixup, but it is questionable whether the ability
// to test logging is worth the slowdown. Also, using a primitive provided by `std` to
// synchronize the logger across tests is not an option since we need to support `no-std`.
// Maybe the mocking approach can be reconsidered once tests can be run in separate processes,
// as proposed in the issue below:
// https://github.com/rust-lang/rust/issues/47506
//
// As a side note, a helper crate exists for exactly this purpose:
// https://crates.io/crates/testing_logger
// but it is unmaintained and using it would introduce another dependency.
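For reference only (not part of this commit, and, as explained above, not viable here because the single global logger is shared by parallel tests and the crate must support `no-std`): a rough sketch of the mock-logger idea using the `log` crate's global logger.

```rust
use std::sync::Mutex;

// A process-global logger that collects warning messages for later inspection.
// Because `log::set_logger` installs exactly one logger per process, parallel
// tests would all write into this single buffer, which is the mixup described above.
struct CollectingLogger {
    messages: Mutex<Vec<String>>,
}

impl log::Log for CollectingLogger {
    fn enabled(&self, _metadata: &log::Metadata) -> bool {
        true
    }

    fn log(&self, record: &log::Record) {
        self.messages.lock().unwrap().push(record.args().to_string());
    }

    fn flush(&self) {}
}

static LOGGER: CollectingLogger = CollectingLogger {
    messages: Mutex::new(Vec::new()),
};

fn install() {
    log::set_logger(&LOGGER).unwrap();
    log::set_max_level(log::LevelFilter::Warn);
}
```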

#[test]
#[should_panic]
fn test_config_step_size_zero() {
StepLrSchedulerConfig::new(1.0, 0).init();
}

#[test]
fn test_config_step_size_nonzero() {
StepLrSchedulerConfig::new(1.0, 1).init();
}

#[test]
fn test_config_default_gamma() {
const INIT_LR: LearningRate = 0.4;
const STEP_SIZE: usize = 2;

let mut default = create_scheduler_unchecked(INIT_LR, STEP_SIZE, None);
let mut explicit = create_scheduler_unchecked(INIT_LR, STEP_SIZE, Some(0.1));
test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE);
}

#[test]
fn test_lr_decreasing() {
let scheduler = create_scheduler_unchecked(0.5, 3, Some(0.1));
let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}

#[test]
fn test_lr_increasing() {
let scheduler = create_scheduler_unchecked(0.1, 2, Some(2.0));
let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}

#[test]
fn test_lr_unchanging() {
let scheduler = create_scheduler_unchecked(3.1, 1, Some(1.0));
let expected_lrs = [3.1, 3.1, 3.1];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}

#[test]
fn test_save_and_load() {
const STEP_SIZE: usize = 10;

let scheduler = create_scheduler_unchecked(0.007, STEP_SIZE, Some(0.03));
test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);
}

// It's too time-consuming to actually run a scheduler for `i32::MAX` steps, so the tests below
// rely on the scheduler's record (its private iteration index) instead.
#[test]
fn test_number_of_calls_within_limit() {
// Create a scheduler that has already run `i32::MAX` steps
let mut scheduler = create_scheduler_unchecked(0.1, 2, None);
scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
scheduler.step();
}

#[test]
#[should_panic = "i32::MAX"]
fn test_number_of_calls_over_limit() {
// Create a scheduler that has already run `i32::MAX` steps
let mut scheduler = create_scheduler_unchecked(0.1, 2, None);
scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
scheduler.step();
scheduler.step();
}

// Create a scheduler with valid parameters
fn create_scheduler_unchecked(
init_lr: LearningRate,
step_size: usize,
gamma: Option<f64>,
) -> StepLrScheduler {
let mut config = StepLrSchedulerConfig::new(init_lr, step_size);
if let Some(g) = gamma {
config = config.with_gamma(g);
}
config.init()
}
}
