Skip to content

Commit

Permalink
Feat/dashboard tui (tracel-ai#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 13, 2023
1 parent 4f72578 commit 57d6a56
Show file tree
Hide file tree
Showing 25 changed files with 1,599 additions and 472 deletions.
1 change: 1 addition & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[default]
extend-ignore-identifiers-re = [
"ratatui",
"NdArray*",
"ND"
]
Expand Down
13 changes: 9 additions & 4 deletions burn-core/src/data/dataloader/multithread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@ where
let mut iterator = dataloader_cloned.iter();
while let Some(item) = iterator.next() {
let progress = iterator.progress();
sender_cloned
.send(Message::Batch(index, item, progress))
.unwrap();

match sender_cloned.send(Message::Batch(index, item, progress)) {
Ok(_) => {}
// The receiver is probably gone, no need to panic, just need to stop
// iterating.
Err(_) => return,
};
}
sender_cloned.send(Message::Done).unwrap();
// As above: the receiver may already be gone, so ignore a failed send instead of panicking.
sender_cloned.send(Message::Done).ok();
})
})
.collect();
Expand Down
16 changes: 6 additions & 10 deletions burn-train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-train"
version = "0.10.0"

[features]
default = ["metrics", "ui"]
default = ["metrics", "tui"]
metrics = [
"nvml-wrapper",
"sysinfo",
"systemstat"
]
ui = [
"indicatif",
"rgb",
"terminal_size",
"textplots",
tui = [
"ratatui",
"crossterm"
]

[dependencies]
Expand All @@ -38,10 +36,8 @@ sysinfo = { version = "0.29.8", optional = true }
systemstat = { version = "0.2.3", optional = true }

# Text UI
indicatif = { version = "0.17.5", optional = true }
rgb = { version = "0.8.36", optional = true }
terminal_size = { version = "0.2.6", optional = true }
textplots = { version = "0.8.0", optional = true }
ratatui = { version = "0.23", optional = true, features = ["all-widgets"] }
crossterm = { version = "0.27", optional = true }

# Utilities
derive-new = {workspace = true}
Expand Down
7 changes: 4 additions & 3 deletions burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::dashboard::CLIDashboardRenderer;
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics};
use crate::metric::dashboard::{
Dashboard, DashboardRenderer, MetricWrapper, Metrics, SelectedDashboardRenderer,
};
use crate::metric::{Adaptor, Metric};
use crate::AsyncTrainerCallback;
use burn_core::lr_scheduler::LRScheduler;
Expand Down Expand Up @@ -259,7 +260,7 @@ where
}
let renderer = self
.renderer
.unwrap_or_else(|| Box::new(CLIDashboardRenderer::new()));
.unwrap_or_else(|| Box::new(SelectedDashboardRenderer::new(self.interrupter.clone())));
let directory = &self.directory;
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
Expand Down
6 changes: 6 additions & 0 deletions burn-train/src/learner/epoch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ impl<TI> TrainEpoch<TI> {

// The main device is always the first in the list.
let device_main = devices.get(0).unwrap().clone();
let mut interrupted = false;

loop {
let items = step.step(&mut iterator, &model);
Expand Down Expand Up @@ -234,9 +235,14 @@ impl<TI> TrainEpoch<TI> {
callback.on_train_item(item);
if interrupter.should_stop() {
log::info!("Training interrupted.");
interrupted = true;
break;
}
}

if interrupted {
break;
}
}

callback.on_train_end_epoch(self.epoch);
Expand Down
4 changes: 4 additions & 0 deletions burn-train/src/learner/train_val.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ where
);
}

if self.interrupter.should_stop() {
break;
}

let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
epoch_valid.run(&model, &mut self.callback, &self.interrupter);

Expand Down
10 changes: 10 additions & 0 deletions burn-train/src/metric/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ pub struct MetricEntry {
/// The string to be saved.
pub serialize: String,
}

/// Format a float with the given precision. Will use scientific notation if necessary.
///
/// Scientific notation is used when the magnitude of `float` is at or below
/// `0.1^(precision - 1)`; otherwise plain fixed-point notation is used. The
/// decision is based on the absolute value, so a negative number is formatted
/// the same way as its positive counterpart (with a leading `-`).
///
/// Zero is at or below every threshold and is therefore rendered in scientific
/// notation as well (e.g. `0.00e0` for a precision of 2).
///
/// # Arguments
///
/// * `float` - The value to format.
/// * `precision` - The number of digits printed after the decimal point.
///
/// # Returns
///
/// The formatted value as a `String`.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude rather than the signed value: otherwise every
    // negative number would fall below the (positive) threshold and always be
    // printed in scientific notation.
    if float.abs() <= scientific_notation_threshold {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}
8 changes: 7 additions & 1 deletion burn-train/src/metric/cpu_use.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/// The CPU use metric.
use super::MetricMetadata;
use super::{MetricMetadata, Numeric};
use crate::metric::{Metric, MetricEntry};
use sysinfo::{CpuExt, System, SystemExt};

Expand Down Expand Up @@ -59,3 +59,9 @@ impl Metric for CpuUse {

fn clear(&mut self) {}
}

impl Numeric for CpuUse {
    /// Returns the stored `use_percentage` converted to an `f64`.
    fn value(&self) -> f64 {
        let percentage = self.use_percentage;
        percentage as f64
    }
}
Loading

0 comments on commit 57d6a56

Please sign in to comment.