Add step learning rate scheduler (#2423)
* Add step learning rate scheduler (#1198)
* Add a test utility function for comparing outputs of two schedulers
* fixup! Add step learning rate scheduler (#1198)
* Assert instead of returning `Result`
* Warn about atypical values of initial LR and gamma
Showing 3 changed files with 230 additions and 12 deletions.
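The diff below implements the standard step decay rule: at iteration `i` (0-based), the learning rate is `initial_lr * gamma^floor(i / step_size)`. As a dependency-free sanity check of that rule, here is a small sketch; the `step_lr` helper is illustrative only, not part of the commit:

```rust
/// Step decay: lr(i) = initial_lr * gamma^floor(i / step_size).
fn step_lr(initial_lr: f64, gamma: f64, step_size: usize, iter: usize) -> f64 {
    initial_lr * gamma.powi((iter / step_size) as i32)
}

fn main() {
    // Matches the sequence asserted by `test_lr_decreasing` in the diff:
    // 0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005
    for i in 0..9 {
        println!("{}", step_lr(0.5, 0.1, 3, i));
    }
}
```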
New file, 210 lines added:
```rust
use burn_tensor::backend::Backend;

use crate as burn;

use super::LrScheduler;
use crate::{config::Config, LearningRate};

/// The configuration for creating a [step learning rate scheduler](StepLrScheduler).
///
/// This scheduler returns the learning rate `initial_lr` at first and keeps doing so until that
/// value has been returned `step_size` times. It then multiplies the learning rate by `gamma` and
/// repeats the process.
///
/// Gamma values outside the range (0.0, 1.0) and non-positive initial learning rates are
/// accepted, but a warning is logged for such values since they may indicate a typo.
///
/// ## Notes
///
/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than
/// `i32::MAX + 1` times.
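///
/// ## Examples
///
/// With `initial_lr = 1.0`, `step_size = 2`, and `gamma = 0.5`, successive calls to
/// [step](StepLrScheduler::step) return: 1.0, 1.0, 0.5, 0.5, 0.25, 0.25, ...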
#[derive(Config)]
pub struct StepLrSchedulerConfig {
    /// The learning rate at the initial step.
    initial_lr: LearningRate,
    /// The number of iterations over which the learning rate remains unchanged before the next
    /// update.
    step_size: usize,
    /// The factor by which the learning rate is multiplied with each update. Default: 0.1.
    #[config(default = 0.1)]
    gamma: f64,
}

impl StepLrSchedulerConfig {
    /// Initializes a [step learning rate scheduler](StepLrScheduler).
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is 0.
    pub fn init(&self) -> StepLrScheduler {
        assert!(self.step_size > 0, "Step size must be greater than 0.");

        // Atypical values of `initial_lr` and `gamma` are not rejected because they might be
        // useful in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).
        if self.initial_lr <= 0.0 {
            log::warn!(
                "Initial learning rate value of {} is not a positive number. Ignore this \
                 warning if it is intended.",
                self.initial_lr
            );
        }
        if self.gamma <= 0.0 || self.gamma >= 1.0 {
            log::warn!(
                "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
                 intended.",
                self.gamma
            );
        }

        StepLrScheduler {
            init_lr: self.initial_lr,
            step_size: self.step_size,
            gamma: self.gamma,
            iter_idx: -1,
        }
    }
}
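
// A hedged usage sketch (`new` and `with_gamma` are generated by the `Config`
// derive, as exercised by the tests at the bottom of this file):
//
//     let scheduler = StepLrSchedulerConfig::new(0.1, 30).with_gamma(0.5).init();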

/// Step learning rate scheduler.
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
    init_lr: LearningRate,
    step_size: usize,
    gamma: f64,
    // The index of the current iteration.
    // `i32` is used to avoid truncating the exponent when taking powers of `gamma`.
    iter_idx: i32,
}

impl LrScheduler for StepLrScheduler {
    type Record<B: Backend> = i32;

    fn step(&mut self) -> LearningRate {
        self.iter_idx = self
            .iter_idx
            .checked_add(1)
            .expect("`.step()` should be called no more than `i32::MAX + 1` times");
        // The casts below cannot truncate: `iter_idx` is non-negative here, and the quotient is
        // at most `iter_idx`, so it fits in `i32`.
        self.init_lr
            * self
                .gamma
                .powi((self.iter_idx as usize / self.step_size) as i32)
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.iter_idx
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.iter_idx = record;
        self
    }
}
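
// Checkpointing note: the record is only the iteration index, so a restored
// scheduler resumes the decay schedule where it left off. A hedged sketch,
// assuming some backend type `B: Backend`:
//
//     let record = scheduler.to_record::<B>();
//     let resumed = other_scheduler.load_record::<B>(record);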

#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    // Warning logs for initial LR and gamma are not tested because there seems to be no
    // straightforward way to do it.
    //
    // Creating a mock logger that collects logs into a `String` for later examination seems a
    // possible solution, but unit tests run in parallel in the same process, where the single
    // logger would be shared by multiple tests, so logs from different tests would be mixed up
    // with no easy way to separate them.
    // Using "--test-threads=1" could prevent the mixup, but whether the ability to test logging
    // is worth the slowdown is questionable. Also, using a primitive provided by `std` to
    // synchronize the logger across tests is not an option since we need to support `no-std`.
    // Maybe the mocking approach can be reconsidered once there is an option to run tests in
    // separate processes, as proposed in:
    // https://github.com/rust-lang/rust/issues/47506
    //
    // As a side note, a helper crate exists for this exact purpose:
    // https://crates.io/crates/testing_logger
    // but it is unmaintained, and using it would introduce another dependency.

    #[test]
    #[should_panic]
    fn test_config_step_size_zero() {
        StepLrSchedulerConfig::new(1.0, 0).init();
    }

    #[test]
    fn test_config_step_size_nonzero() {
        StepLrSchedulerConfig::new(1.0, 1).init();
    }

    #[test]
    fn test_config_default_gamma() {
        const INIT_LR: LearningRate = 0.4;
        const STEP_SIZE: usize = 2;

        let mut default = create_scheduler_unchecked(INIT_LR, STEP_SIZE, None);
        let mut explicit = create_scheduler_unchecked(INIT_LR, STEP_SIZE, Some(0.1));
        test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE);
    }

    #[test]
    fn test_lr_decreasing() {
        let scheduler = create_scheduler_unchecked(0.5, 3, Some(0.1));
        let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_increasing() {
        let scheduler = create_scheduler_unchecked(0.1, 2, Some(2.0));
        let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_lr_unchanging() {
        let scheduler = create_scheduler_unchecked(3.1, 1, Some(1.0));
        let expected_lrs = [3.1, 3.1, 3.1];
        test_utils::check_lr_sequence(scheduler, expected_lrs);
    }

    #[test]
    fn test_save_and_load() {
        const STEP_SIZE: usize = 10;

        let scheduler = create_scheduler_unchecked(0.007, STEP_SIZE, Some(0.03));
        test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);
    }

    // It's too time-consuming to actually run a scheduler for `i32::MAX` steps, so these tests
    // rely on private fields instead.
    #[test]
    fn test_number_of_calls_within_limit() {
        // Create a scheduler that has effectively already run `i32::MAX` steps.
        let mut scheduler = create_scheduler_unchecked(0.1, 2, None);
        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
    }

    #[test]
    #[should_panic = "i32::MAX"]
    fn test_number_of_calls_over_limit() {
        // Create a scheduler that has effectively already run `i32::MAX` steps.
        let mut scheduler = create_scheduler_unchecked(0.1, 2, None);
        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);
        scheduler.step();
        scheduler.step();
    }

    // Create a scheduler with valid parameters.
    fn create_scheduler_unchecked(
        init_lr: LearningRate,
        step_size: usize,
        gamma: Option<f64>,
    ) -> StepLrScheduler {
        let mut config = StepLrSchedulerConfig::new(init_lr, step_size);
        if let Some(g) = gamma {
            config = config.with_gamma(g);
        }
        config.init()
    }
}
```
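
For context, a minimal sketch of how the new scheduler slots into a training loop. The module path and the loop around it are assumptions for illustration; only `StepLrSchedulerConfig`, `new`, `with_gamma`, `init`, and `step` come from the diff above.

```rust
// Assumed re-export path; the file above lives in burn's `lr_scheduler` module.
use burn::lr_scheduler::step::StepLrSchedulerConfig;
use burn::lr_scheduler::LrScheduler;

fn run() {
    // Start at lr = 0.1 and multiply by 0.5 every 30 iterations.
    let mut scheduler = StepLrSchedulerConfig::new(0.1, 30).with_gamma(0.5).init();

    for _iter in 0..100 {
        let lr = scheduler.step();
        // ... feed `lr` to the optimizer step for this iteration ...
        let _ = lr;
    }
}
```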