Skip to content

Commit

Permalink
Fix training checkpoints (tracel-ai#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 21, 2023
1 parent ac4adb5 commit aacf191
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 25 deletions.
6 changes: 3 additions & 3 deletions burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ where
if self.log_to_file {
self.init_logger();
}
let renderer = self
.renderer
.unwrap_or_else(|| Box::new(default_renderer(self.interrupter.clone())));
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
});
let directory = &self.directory;
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
Expand Down
2 changes: 1 addition & 1 deletion burn-train/src/learner/train_val.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ where
let starting_epoch = match self.checkpoint {
Some(checkpoint) => {
self = self.load_checkpoint(checkpoint);
checkpoint
checkpoint + 1
}
None => 1,
};
Expand Down
7 changes: 5 additions & 2 deletions burn-train/src/metric/dashboard/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ pub use tui::TuiDashboardRenderer as SelectedDashboardRenderer;

/// The TUI renderer, or a simple stub if the tui feature is not enabled.
#[allow(unused_variables)]
pub(crate) fn default_renderer(interuptor: TrainingInterrupter) -> SelectedDashboardRenderer {
pub(crate) fn default_renderer(
interuptor: TrainingInterrupter,
checkpoint: Option<usize>,
) -> SelectedDashboardRenderer {
#[cfg(feature = "tui")]
return SelectedDashboardRenderer::new(interuptor);
return SelectedDashboardRenderer::new(interuptor, checkpoint);

#[cfg(not(feature = "tui"))]
return SelectedDashboardRenderer::new();
Expand Down
64 changes: 47 additions & 17 deletions burn-train/src/metric/dashboard/tui/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,29 @@ use std::time::Instant;
///
/// We currently ignore the time taken for the validation part.
pub(crate) struct ProgressBarState {
progress_train: f64,
progress_train: f64, // Progress for total training.
progress_train_for_eta: f64, // Progress considering the starting epoch.
starting_epoch: usize,
started: Instant,
}

impl Default for ProgressBarState {
fn default() -> Self {
Self {
progress_train: 0.0,
started: Instant::now(),
}
}
}

const MINUTE: u64 = 60;
const HOUR: u64 = 60 * 60;
const DAY: u64 = 24 * 60 * 60;

impl ProgressBarState {
pub fn new(checkpoint: Option<usize>) -> Self {
Self {
progress_train: 0.0,
progress_train_for_eta: 0.0,
started: Instant::now(),
starting_epoch: checkpoint.unwrap_or(0),
}
}
/// Update the training progress.
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
let total_items = progress.progress.items_total * progress.epoch_total;
let epoch_items = (progress.epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed as f64;

self.progress_train = (epoch_items as f64 + iteration_items) / total_items as f64
self.progress_train = calculate_progress(progress, 0);
self.progress_train_for_eta = calculate_progress(progress, self.starting_epoch);
}

/// Update the validation progress.
Expand All @@ -47,10 +45,10 @@ impl ProgressBarState {
/// Create a view for the current progress.
pub(crate) fn view(&self) -> ProgressBarView {
let eta = self.started.elapsed();
let total_estimated = (eta.as_secs() as f64) / self.progress_train;
let total_estimated = (eta.as_secs() as f64) / self.progress_train_for_eta;

let eta = if total_estimated.is_normal() {
let remaining = 1.0 - self.progress_train;
let remaining = 1.0 - self.progress_train_for_eta;
let eta = (total_estimated * remaining) as u64;
format_eta(eta)
} else {
Expand Down Expand Up @@ -106,6 +104,17 @@ impl ProgressBarView {
}
}

fn calculate_progress(progress: &TrainingProgress, starting_epoch: usize) -> f64 {
let epoch_total = progress.epoch_total - starting_epoch;
let epoch = progress.epoch - starting_epoch;

let total_items = progress.progress.items_total * epoch_total;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed as f64;

(epoch_items as f64 + iteration_items) / total_items as f64
}

fn format_eta(eta_secs: u64) -> String {
let seconds = eta_secs % 60;
let minutes = eta_secs / MINUTE % 60;
Expand All @@ -130,6 +139,7 @@ fn format_eta(eta_secs: u64) -> String {
#[cfg(test)]
mod tests {
use super::*;
use burn_core::data::dataloader::Progress;

#[test]
fn test_format_eta() {
Expand All @@ -141,4 +151,24 @@ mod tests {
assert_eq!("1 days", format_eta(24 * 3601), "More than 1 day");
assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day");
}

#[test]
fn calculate_progress_for_eta() {
let half = Progress {
items_processed: 5,
items_total: 10,
};
let progress = TrainingProgress {
progress: half,
epoch: 9,
epoch_total: 10,
iteration: 500,
};

let starting_epoch = 8;
let progress = calculate_progress(&progress, starting_epoch);

// Two epochs remaining while the first is half done.
assert_eq!(0.25, progress);
}
}
4 changes: 2 additions & 2 deletions burn-train/src/metric/dashboard/tui/renderer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl DashboardRenderer for TuiDashboardRenderer {

impl TuiDashboardRenderer {
/// Create a new CLI dashboard renderer.
pub fn new(interuptor: TrainingInterrupter) -> Self {
pub fn new(interuptor: TrainingInterrupter, checkpoint: Option<usize>) -> Self {
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen).unwrap();
enable_raw_mode().unwrap();
Expand All @@ -88,7 +88,7 @@ impl TuiDashboardRenderer {
Self {
terminal,
last_update: Instant::now(),
progress: ProgressBarState::default(),
progress: ProgressBarState::new(checkpoint),
metrics_numeric: NumericMetricsState::default(),
metrics_text: TextMetricsState::default(),
status: StatusState::default(),
Expand Down

0 comments on commit aacf191

Please sign in to comment.