Skip to content

Commit

Permalink
Add warmup logic when calculating eta (tracel-ai#923)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Nov 3, 2023
1 parent 2ac348c commit dddc138
Showing 1 changed file with 107 additions and 22 deletions.
129 changes: 107 additions & 22 deletions burn-train/src/renderer/tui/progress.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
use crate::renderer::TrainingProgress;

use super::TerminalFrame;
use crate::renderer::TrainingProgress;
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Gauge, Paragraph},
};
use std::time::Instant;
use std::time::{Duration, Instant};

/// Simple progress bar for the training.
///
/// We currently ignore the time taken for the validation part.
pub(crate) struct ProgressBarState {
progress_train: f64, // Progress for total training.
progress_train_for_eta: f64, // Progress considering the starting epoch.
progress_train: f64, // Progress for total training.
starting_epoch: usize,
started: Instant,
estimate: ProgressEstimate,
}

const MINUTE: u64 = 60;
Expand All @@ -27,15 +25,14 @@ impl ProgressBarState {
pub fn new(checkpoint: Option<usize>) -> Self {
Self {
progress_train: 0.0,
progress_train_for_eta: 0.0,
started: Instant::now(),
estimate: ProgressEstimate::new(),
starting_epoch: checkpoint.unwrap_or(0),
}
}
/// Update the training progress.
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
self.progress_train = calculate_progress(progress, 0);
self.progress_train_for_eta = calculate_progress(progress, self.starting_epoch);
self.progress_train = calculate_progress(progress, 0, 0);
self.estimate.update(progress, self.starting_epoch);
}

/// Update the validation progress.
Expand All @@ -45,15 +42,11 @@ impl ProgressBarState {

/// Create a view for the current progress.
pub(crate) fn view(&self) -> ProgressBarView {
let eta = self.started.elapsed();
let total_estimated = (eta.as_secs() as f64) / self.progress_train_for_eta;
const NO_ETA: &str = "---";

let eta = if total_estimated.is_normal() {
let remaining = 1.0 - self.progress_train_for_eta;
let eta = (total_estimated * remaining) as u64;
format_eta(eta)
} else {
"---".to_string()
let eta = match self.estimate.secs() {
Some(eta) => format_eta(eta),
None => NO_ETA.to_string(),
};
ProgressBarView::new(self.progress_train, eta)
}
Expand Down Expand Up @@ -105,15 +98,87 @@ impl ProgressBarView {
}
}

fn calculate_progress(progress: &TrainingProgress, starting_epoch: usize) -> f64 {
struct ProgressEstimate {
started: Instant,
started_after_warmup: Option<Instant>,
warmup_num_items: usize,
progress: f64,
}

impl ProgressEstimate {
fn new() -> Self {
Self {
started: Instant::now(),
started_after_warmup: None,
warmup_num_items: 0,
progress: 0.0,
}
}

fn secs(&self) -> Option<u64> {
let eta = match self.started_after_warmup {
Some(started) => started.elapsed(),
None => return None,
};

let total_estimated = (eta.as_secs() as f64) / self.progress;

if total_estimated.is_normal() {
let remaining = 1.0 - self.progress;
let eta = (total_estimated * remaining) as u64;
Some(eta)
} else {
None
}
}

fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
if self.started_after_warmup.is_some() {
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
return;
}

const WARMUP_NUM_ITERATION: usize = 10;

// When the training has started since 30 seconds.
if self.started.elapsed() > Duration::from_secs(30) {
self.init(progress, starting_epoch);
return;
}

// When the training has started since at least 10 seconds and completed 10 iterations.
if progress.iteration >= WARMUP_NUM_ITERATION
&& self.started.elapsed() > Duration::from_secs(10)
{
self.init(progress, starting_epoch);
}
}

fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {
let epoch = progress.epoch - starting_epoch;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed;

self.warmup_num_items = epoch_items + iteration_items;
self.started_after_warmup = Some(Instant::now());
self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);
}
}

fn calculate_progress(
progress: &TrainingProgress,
starting_epoch: usize,
ignore_num_items: usize,
) -> f64 {
let epoch_total = progress.epoch_total - starting_epoch;
let epoch = progress.epoch - starting_epoch;

let total_items = progress.progress.items_total * epoch_total;
let epoch_items = (epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed as f64;
let iteration_items = progress.progress.items_processed;
let num_items = epoch_items + iteration_items - ignore_num_items;

(epoch_items as f64 + iteration_items) / total_items as f64
num_items as f64 / total_items as f64
}

fn format_eta(eta_secs: u64) -> String {
Expand Down Expand Up @@ -171,9 +236,29 @@ mod tests {
};

let starting_epoch = 8;
let progress = calculate_progress(&progress, starting_epoch);
let progress = calculate_progress(&progress, starting_epoch, 0);

// Two epochs remaining while the first is half done.
assert_eq!(0.25, progress);
}

#[test]
fn calculate_progress_for_eta_with_warmup() {
let half = Progress {
items_processed: 110,
items_total: 1000,
};
let progress = TrainingProgress {
progress: half,
epoch: 9,
epoch_total: 10,
iteration: 500,
};

let starting_epoch = 8;
let progress = calculate_progress(&progress, starting_epoch, 10);

// Two epochs remaining while the first is half done.
assert_eq!(0.05, progress);
}
}

0 comments on commit dddc138

Please sign in to comment.