Skip to content

Commit

Permalink
Make .init() method of LR schedulers return Result (tracel-ai#2527)
Browse files Browse the repository at this point in the history
* Make `.init()` return `Result` instead of panicking for all the
  schedulers except `ConstantLr`, which has nothing to check.

* Update affected examples

* Add check and test code for Noam LR scheduler

* Clean up test code
  • Loading branch information
towerpark authored Nov 25, 2024
1 parent 9408d5f commit 23a6504
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 143 deletions.
2 changes: 2 additions & 0 deletions crates/burn-core/src/lr_scheduler/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub(super) use alloc::string::String;

use burn_tensor::backend::Backend;

use crate::{record::Record, LearningRate};
Expand Down
100 changes: 67 additions & 33 deletions crates/burn-core/src/lr_scheduler/cosine.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::LrScheduler;
use super::{LrScheduler, String};
use crate as burn;
use crate::{config::Config, LearningRate};
use burn_tensor::backend::Backend;
Expand All @@ -24,28 +24,34 @@ pub struct CosineAnnealingLrSchedulerConfig {
impl CosineAnnealingLrSchedulerConfig {
/// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).
///
/// # Panics
/// This function panics if `initial_lr` and `final_lr` are not between 0 and 1.
pub fn init(&self) -> CosineAnnealingLrScheduler {
assert!(
self.initial_lr > 0. && self.initial_lr <= 1.,
"Initial learning rate must be greater than 0 and at most 1"
);
assert!(
self.min_lr >= 0.0 && self.min_lr <= self.initial_lr,
"Minimum learning rate must be at least 0 and at most equal to the initial learning rate"
);
assert!(
self.num_iters > 0,
"Number of iterations must be at least 1"
);
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `initial_lr` is out of range (0.0, 1.0]
/// * `min_lr` is out of range [0.0, `initial_lr`]
/// * `num_iters` is 0
pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {
if self.initial_lr <= 0. || self.initial_lr > 1. {
return Err("Initial learning rate must be greater than 0 and at most 1".into());
}
if self.min_lr < 0.0 || self.min_lr > self.initial_lr {
return Err(
"Minimum learning rate must be at least 0 and at most equal to the initial \
learning rate"
.into(),
);
}
if self.num_iters == 0 {
return Err("Number of iterations must be at least 1".into());
}

CosineAnnealingLrScheduler {
Ok(CosineAnnealingLrScheduler {
min_lr: self.min_lr,
max_lr: self.initial_lr,
num_iters: self.num_iters,
current_iter: usize::MAX,
}
})
}
}

Expand Down Expand Up @@ -94,48 +100,75 @@ mod tests {
use super::*;

#[test]
#[should_panic = "Initial learning rate must be greater than 0 and at most 1"]
fn config_initial_lr_too_low() {
CosineAnnealingLrSchedulerConfig::new(0., 10).init();
let r = CosineAnnealingLrSchedulerConfig::new(0., 10).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
#[should_panic = "Initial learning rate must be greater than 0 and at most 1"]
fn config_initial_lr_too_high() {
CosineAnnealingLrSchedulerConfig::new(1.5, 10).init();
let r = CosineAnnealingLrSchedulerConfig::new(1.5, 10).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
#[should_panic = "Minimum learning rate must be at least 0 and at most equal to the initial learning rate"]
fn config_min_lr_too_low() {
CosineAnnealingLrSchedulerConfig::new(0.5, 10)
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)
.with_min_lr(-0.1)
.init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Minimum learning rate must be at least 0 and at most equal to the initial learning \
rate",
"Error messages should match",
);
}

#[test]
#[should_panic = "Minimum learning rate must be at least 0 and at most equal to the initial learning rate"]
fn config_min_lr_too_high() {
CosineAnnealingLrSchedulerConfig::new(0.5, 10)
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)
.with_min_lr(0.6)
.init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Minimum learning rate must be at least 0 and at most equal to the initial learning \
rate",
"Error messages should match",
);
}

#[test]
#[should_panic = "Number of iterations must be at least 1"]
fn config_num_iters_too_low() {
CosineAnnealingLrSchedulerConfig::new(0.5, 0).init();
let r = CosineAnnealingLrSchedulerConfig::new(0.5, 0).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Number of iterations must be at least 1",
"Error messages should match",
);
}

#[test]
fn test_lr_change() {
const INITIAL_LR: LearningRate = 0.5;
const MIN_LR: LearningRate = 0.1;
const NUM_ITERS: usize = 2;

let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, NUM_ITERS)
let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2)
.with_min_lr(MIN_LR)
.init();
.init()
.unwrap();
let expected_lrs = [
INITIAL_LR, // cos(0)
(INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2)
Expand All @@ -147,9 +180,10 @@ mod tests {

#[test]
fn test_save_and_load() {
const INITIAL_LR: LearningRate = 1.0;
const NUM_ITERS: usize = 9;
let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, NUM_ITERS).init();
let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS)
.init()
.unwrap();
test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);
}
}
77 changes: 48 additions & 29 deletions crates/burn-core/src/lr_scheduler/exponential.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::LrScheduler;
use super::{LrScheduler, String};
use crate as burn;
use crate::{config::Config, LearningRate};
use burn_tensor::backend::Backend;
Expand All @@ -19,24 +19,26 @@ pub struct ExponentialLrSchedulerConfig {
impl ExponentialLrSchedulerConfig {
/// Initializes an [exponential learning rate scheduler](ExponentialLrScheduler).
///
/// # Panics
/// This function panics if `initial_lr` and `gamma` are not between 0 and 1.
pub fn init(&self) -> ExponentialLrScheduler {
assert!(
self.initial_lr > 0. && self.initial_lr <= 1.,
"Initial learning rate must be greater than 0 and at most 1"
);
assert!(
self.gamma > 0. && self.gamma <= 1.,
"Gamma must be greater than 0 and at most 1"
);
/// # Errors
///
/// An error will be returned if any of the following conditions is true:
///
/// * `initial_lr` is out of range (0.0, 1.0]
/// * `gamma` is out of range (0.0, 1.0]
pub fn init(&self) -> Result<ExponentialLrScheduler, String> {
if self.initial_lr <= 0. || self.initial_lr > 1. {
return Err("Initial learning rate must be greater than 0 and at most 1".into());
}
if self.gamma <= 0. || self.gamma > 1. {
return Err("Gamma must be greater than 0 and at most 1".into());
}

ExponentialLrScheduler {
Ok(ExponentialLrScheduler {
// Such an initial value eliminates the need for special-case handling of the first
// learning rate.
previous_lr: self.initial_lr / self.gamma,
gamma: self.gamma,
}
})
}
}

Expand Down Expand Up @@ -75,44 +77,61 @@ mod tests {
use super::*;

#[test]
#[should_panic = "Initial learning rate must be greater than 0 and at most 1"]
fn config_initial_lr_too_low() {
ExponentialLrSchedulerConfig::new(0., 0.5).init();
let r = ExponentialLrSchedulerConfig::new(0., 0.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
#[should_panic = "Initial learning rate must be greater than 0 and at most 1"]
fn config_initial_lr_too_high() {
ExponentialLrSchedulerConfig::new(1.5, 0.5).init();
let r = ExponentialLrSchedulerConfig::new(1.5, 0.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Initial learning rate must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
#[should_panic = "Gamma must be greater than 0 and at most 1"]
fn config_gamma_too_low() {
ExponentialLrSchedulerConfig::new(0.5, 0.0).init();
let r = ExponentialLrSchedulerConfig::new(0.5, 0.0).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Gamma must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
#[should_panic = "Gamma must be greater than 0 and at most 1"]
fn config_gamma_too_high() {
ExponentialLrSchedulerConfig::new(0.5, 1.5).init();
let r = ExponentialLrSchedulerConfig::new(0.5, 1.5).init();
assert!(r.is_err(), "Should return an error");
assert_eq!(
r.unwrap_err(),
"Gamma must be greater than 0 and at most 1",
"Error messages should match",
);
}

#[test]
fn test_lr_change() {
const INITIAL_LR: LearningRate = 0.8;
const GAMMA: f64 = 0.1;

let scheduler = ExponentialLrSchedulerConfig::new(INITIAL_LR, GAMMA).init();
let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();
let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];
test_utils::check_lr_sequence(scheduler, expected_lrs);
}

#[test]
fn test_save_and_load() {
const INITIAL_LR: LearningRate = 0.083;
const GAMMA: f64 = 0.3;
let scheduler = ExponentialLrSchedulerConfig::new(INITIAL_LR, GAMMA).init();
let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)
.init()
.unwrap();
test_utils::check_save_load(scheduler, 7);
}
}
Loading

0 comments on commit 23a6504

Please sign in to comment.