diff --git a/burn-book/src/basic-workflow/training.md b/burn-book/src/basic-workflow/training.md index 7cf2e195f9..286270342b 100644 --- a/burn-book/src/basic-workflow/training.md +++ b/burn-book/src/basic-workflow/training.md @@ -111,10 +111,10 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, device: B .build(MNISTDataset::test()); let learner = LearnerBuilder::new(artifact_dir) - .metric_train_plot(AccuracyMetric::new()) - .metric_valid_plot(AccuracyMetric::new()) - .metric_train_plot(LossMetric::new()) - .metric_valid_plot(LossMetric::new()) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(1, CompactRecorder::new()) .devices(vec![device]) .num_epochs(config.num_epochs) diff --git a/burn-train/src/callback/async_callback.rs b/burn-train/src/callback/async_callback.rs deleted file mode 100644 index 2e0fcbc294..0000000000 --- a/burn-train/src/callback/async_callback.rs +++ /dev/null @@ -1,97 +0,0 @@ -use super::{LearnerCallback, LearnerItem}; -use std::{sync::mpsc, thread::JoinHandle}; - -enum Message { - LogTrain(LearnerItem), - LogValid(LearnerItem), - ClearTrain(usize), - ClearValid(usize), - End, -} - -/// Async trainer callback tracker. 
-pub struct AsyncTrainerCallback { - sender: mpsc::Sender>, - handler: Option>, -} - -#[derive(new)] -struct CallbackThread { - callback: C, - receiver: mpsc::Receiver>, -} - -impl CallbackThread -where - C: LearnerCallback, -{ - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::LogTrain(item) => { - self.callback.on_train_item(item); - } - Message::ClearTrain(epoch) => { - self.callback.on_train_end_epoch(epoch); - } - Message::LogValid(item) => { - self.callback.on_valid_item(item); - } - Message::ClearValid(epoch) => { - self.callback.on_valid_end_epoch(epoch); - } - Message::End => { - return; - } - } - } - } -} - -impl AsyncTrainerCallback { - /// Create a new async trainer callback. - pub fn new(callback: C) -> Self - where - C: LearnerCallback + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = CallbackThread::new(callback, receiver); - - let handler = std::thread::spawn(move || thread.run()); - let handler = Some(handler); - - Self { sender, handler } - } -} - -impl LearnerCallback for AsyncTrainerCallback { - type ItemTrain = T; - type ItemValid = V; - - fn on_train_item(&mut self, item: LearnerItem) { - self.sender.send(Message::LogTrain(item)).unwrap(); - } - - fn on_valid_item(&mut self, item: LearnerItem) { - self.sender.send(Message::LogValid(item)).unwrap(); - } - - fn on_train_end_epoch(&mut self, epoch: usize) { - self.sender.send(Message::ClearTrain(epoch)).unwrap(); - } - - fn on_valid_end_epoch(&mut self, epoch: usize) { - self.sender.send(Message::ClearValid(epoch)).unwrap(); - } -} - -impl Drop for AsyncTrainerCallback { - fn drop(&mut self) { - self.sender.send(Message::End).unwrap(); - let handler = self.handler.take(); - - if let Some(handler) = handler { - handler.join().unwrap(); - } - } -} diff --git a/burn-train/src/callback/base.rs b/burn-train/src/callback/base.rs deleted file mode 100644 index 6750522104..0000000000 --- a/burn-train/src/callback/base.rs +++ /dev/null @@ -1,43 +0,0 
@@ -use burn_core::{data::dataloader::Progress, LearningRate}; - -/// The base trait for trainer callbacks. -pub trait LearnerCallback: Send { - /// Training item. - type ItemTrain; - /// Validation item. - type ItemValid; - - /// Called when a training item is logged. - fn on_train_item(&mut self, _item: LearnerItem) {} - - /// Called when a validation item is logged. - fn on_valid_item(&mut self, _item: LearnerItem) {} - - /// Called when a training epoch is finished. - fn on_train_end_epoch(&mut self, _epoch: usize) {} - - /// Called when a validation epoch is finished. - fn on_valid_end_epoch(&mut self, _epoch: usize) {} -} - -/// A learner item. -#[derive(new)] -pub struct LearnerItem { - /// The item. - pub item: T, - - /// The progress. - pub progress: Progress, - - /// The epoch. - pub epoch: usize, - - /// The total number of epochs. - pub epoch_total: usize, - - /// The iteration. - pub iteration: usize, - - /// The learning rate. - pub lr: Option, -} diff --git a/burn-train/src/callback/mod.rs b/burn-train/src/callback/mod.rs deleted file mode 100644 index 809b581de4..0000000000 --- a/burn-train/src/callback/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod async_callback; -mod base; - -pub use async_callback::*; -pub use base::*; diff --git a/burn-train/src/collector/async_collector.rs b/burn-train/src/collector/async_collector.rs new file mode 100644 index 0000000000..eb4f75d705 --- /dev/null +++ b/burn-train/src/collector/async_collector.rs @@ -0,0 +1,118 @@ +use super::EventCollector; +use crate::{Aggregate, Direction, Event, Split}; +use std::{sync::mpsc, thread::JoinHandle}; + +enum Message { + OnEventTrain(Event), + OnEventValid(Event), + End, + FindEpoch( + String, + Aggregate, + Direction, + Split, + mpsc::SyncSender>, + ), +} + +/// Async [event collector](EventCollector). +/// +/// This will create a worker thread where all the computation is done ensuring that the training loop is +/// never blocked by metric calculation. 
+pub struct AsyncEventCollector { + sender: mpsc::Sender>, + handler: Option>, +} + +#[derive(new)] +struct WorkerThread { + collector: C, + receiver: mpsc::Receiver>, +} + +impl WorkerThread +where + C: EventCollector, +{ + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::End => { + return; + } + Message::FindEpoch(name, aggregate, direction, split, sender) => { + let response = self + .collector + .find_epoch(&name, aggregate, direction, split); + sender.send(response).unwrap(); + } + Message::OnEventTrain(event) => self.collector.on_event_train(event), + Message::OnEventValid(event) => self.collector.on_event_valid(event), + } + } + } +} + +impl AsyncEventCollector { + /// Create a new async [event collector](EventCollector). + pub fn new(collector: C) -> Self + where + C: EventCollector + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = WorkerThread::new(collector, receiver); + + let handler = std::thread::spawn(move || thread.run()); + let handler = Some(handler); + + Self { sender, handler } + } +} + +impl EventCollector for AsyncEventCollector { + type ItemTrain = T; + type ItemValid = V; + + fn on_event_train(&mut self, event: Event) { + self.sender.send(Message::OnEventTrain(event)).unwrap(); + } + + fn on_event_valid(&mut self, event: Event) { + self.sender.send(Message::OnEventValid(event)).unwrap(); + } + + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::FindEpoch( + name.to_string(), + aggregate, + direction, + split, + sender, + )) + .unwrap(); + + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Async server crashed: {:?}", err), + } + } +} + +impl Drop for AsyncEventCollector { + fn drop(&mut self) { + self.sender.send(Message::End).unwrap(); + let handler = self.handler.take(); + + if let Some(handler) = handler { + 
handler.join().unwrap(); + } + } +} diff --git a/burn-train/src/collector/base.rs b/burn-train/src/collector/base.rs new file mode 100644 index 0000000000..54be28cd36 --- /dev/null +++ b/burn-train/src/collector/base.rs @@ -0,0 +1,78 @@ +use burn_core::{data::dataloader::Progress, LearningRate}; + +/// Event happening during the training/validation process. +pub enum Event { + /// Signal that an item has been processed. + ProcessedItem(LearnerItem), + /// Signal the end of an epoch. + EndEpoch(usize), +} + +/// Defines how training and validation events are collected. +/// +/// This trait also exposes methods that use the collected data to compute useful information. +pub trait EventCollector: Send { + /// Training item. + type ItemTrain; + /// Validation item. + type ItemValid; + + /// Collect the training event. + fn on_event_train(&mut self, event: Event); + + /// Collect the validation event. + fn on_event_valid(&mut self, event: Event); + + /// Find the epoch following the given criteria from the collected data. + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option; +} + +/// How to aggregate the metric. +pub enum Aggregate { + /// Compute the average. + Mean, +} + +/// The split to use. +pub enum Split { + /// The training split. + Train, + /// The validation split. + Valid, +} + +/// The direction of the query. +pub enum Direction { + /// Lower is better. + Lowest, + /// Higher is better. + Highest, +} + +/// A learner item. +#[derive(new)] +pub struct LearnerItem { + /// The item. + pub item: T, + + /// The progress. + pub progress: Progress, + + /// The epoch. + pub epoch: usize, + + /// The total number of epochs. + pub epoch_total: usize, + + /// The iteration. + pub iteration: usize, + + /// The learning rate. 
+ pub lr: Option, +} diff --git a/burn-train/src/collector/metrics/base.rs b/burn-train/src/collector/metrics/base.rs new file mode 100644 index 0000000000..4b1ad2af47 --- /dev/null +++ b/burn-train/src/collector/metrics/base.rs @@ -0,0 +1,131 @@ +use crate::{ + info::MetricsInfo, + metric::MetricMetadata, + renderer::{MetricState, MetricsRenderer, TrainingProgress}, + Aggregate, Direction, Event, EventCollector, LearnerItem, Split, +}; + +/// Collect training events in order to display metrics with a metrics renderer. +#[derive(new)] +pub(crate) struct RenderedMetricsEventCollector +where + T: Send + Sync + 'static, + V: Send + Sync + 'static, +{ + renderer: Box, + info: MetricsInfo, +} + +impl EventCollector for RenderedMetricsEventCollector +where + T: Send + Sync + 'static, + V: Send + Sync + 'static, +{ + type ItemTrain = T; + type ItemValid = V; + + fn on_event_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => self.on_train_item(item), + Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch), + } + } + + fn on_event_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => self.on_valid_item(item), + Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch), + } + } + + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + self.info.find_epoch(name, aggregate, direction, split) + } +} + +impl RenderedMetricsEventCollector +where + T: Send + Sync + 'static, + V: Send + Sync + 'static, +{ + fn on_train_item(&mut self, item: LearnerItem) { + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.info.update_train(&item, &metadata); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_train(MetricState::Numeric(entry, value)) + }); + + 
self.renderer.render_train(progress); + } + + fn on_valid_item(&mut self, item: LearnerItem) { + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.info.update_valid(&item, &metadata); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_valid(MetricState::Numeric(entry, value)) + }); + + self.renderer.render_valid(progress); + } + + fn on_train_end_epoch(&mut self, epoch: usize) { + self.info.end_epoch_train(epoch); + } + + fn on_valid_end_epoch(&mut self, epoch: usize) { + self.info.end_epoch_valid(epoch); + } +} + +impl From<&LearnerItem> for TrainingProgress { + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + } + } +} + +impl From<&LearnerItem> for MetricMetadata { + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + lr: item.lr, + } + } +} diff --git a/burn-train/src/collector/metrics/mod.rs b/burn-train/src/collector/metrics/mod.rs new file mode 100644 index 0000000000..41e113f920 --- /dev/null +++ b/burn-train/src/collector/metrics/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub(crate) use base::*; diff --git a/burn-train/src/collector/mod.rs b/burn-train/src/collector/mod.rs new file mode 100644 index 0000000000..ffb89eb298 --- /dev/null +++ b/burn-train/src/collector/mod.rs @@ -0,0 +1,8 @@ +mod async_collector; +mod base; + +pub use async_collector::*; +pub use base::*; + +/// Metrics collector module. 
+pub mod metrics; diff --git a/burn-train/src/components.rs b/burn-train/src/components.rs index 3318484d40..0ae45057df 100644 --- a/burn-train/src/components.rs +++ b/burn-train/src/components.rs @@ -1,4 +1,4 @@ -use crate::{checkpoint::Checkpointer, LearnerCallback}; +use crate::{checkpoint::Checkpointer, EventCollector}; use burn_core::{ lr_scheduler::LrScheduler, module::{ADModule, Module}, @@ -25,8 +25,8 @@ pub trait LearnerComponents { >; /// The checkpointer used for the scheduler. type CheckpointerLrScheduler: Checkpointer<::Record>; - /// Callback used for training tracking. - type Callback: LearnerCallback + 'static; + /// Training event collector used for training tracking. + type EventCollector: EventCollector + 'static; } /// Concrete type that implements [training components trait](TrainingComponents). @@ -41,8 +41,8 @@ pub struct LearnerComponentsMarker { _callback: PhantomData, } -impl LearnerComponents - for LearnerComponentsMarker +impl LearnerComponents + for LearnerComponentsMarker where B: ADBackend, LR: LrScheduler, @@ -51,7 +51,7 @@ where CM: Checkpointer, CO: Checkpointer, CS: Checkpointer, - C: LearnerCallback + 'static, + EC: EventCollector + 'static, { type Backend = B; type LrScheduler = LR; @@ -60,5 +60,5 @@ where type CheckpointerModel = CM; type CheckpointerOptimizer = CO; type CheckpointerLrScheduler = CS; - type Callback = C; + type EventCollector = EC; } diff --git a/burn-train/src/info/aggregates.rs b/burn-train/src/info/aggregates.rs new file mode 100644 index 0000000000..e270b14ac9 --- /dev/null +++ b/burn-train/src/info/aggregates.rs @@ -0,0 +1,163 @@ +use crate::{logger::MetricLogger, Aggregate, Direction}; +use std::collections::HashMap; + +/// Type that can be used to fetch and use numeric metric aggregates. 
+#[derive(Default, Debug)] +pub(crate) struct NumericMetricsAggregate { + mean_for_each_epoch: HashMap, +} + +#[derive(new, Hash, PartialEq, Eq, Debug)] +struct Key { + name: String, + epoch: usize, +} + +impl NumericMetricsAggregate { + pub(crate) fn mean( + &mut self, + name: &str, + epoch: usize, + loggers: &mut [Box], + ) -> Option { + let key = Key::new(name.to_string(), epoch); + + if let Some(value) = self.mean_for_each_epoch.get(&key) { + return Some(*value); + } + + let points = || { + let mut errors = Vec::new(); + for logger in loggers { + match logger.read_numeric(name, epoch) { + Ok(points) => return Ok(points), + Err(err) => errors.push(err), + }; + } + + Err(errors.join(" ")) + }; + + let points = points().expect("Can read values"); + + if points.is_empty() { + return None; + } + + let num_points = points.len(); + let mean = points.into_iter().sum::() / num_points as f64; + + self.mean_for_each_epoch.insert(key, mean); + Some(mean) + } + + pub(crate) fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + loggers: &mut [Box], + ) -> Option { + let mut data = Vec::new(); + let mut current_epoch = 1; + + loop { + match aggregate { + Aggregate::Mean => match self.mean(name, current_epoch, loggers) { + Some(value) => { + data.push(value); + } + None => break, + }, + }; + + current_epoch += 1; + } + + if data.is_empty() { + return None; + } + + let mut current_value = match &direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + for (i, value) in data.into_iter().enumerate() { + match &direction { + Direction::Lowest => { + if value < current_value { + current_value = value; + current_epoch = i + 1; + } + } + Direction::Highest => { + if value > current_value { + current_value = value; + current_epoch = i + 1; + } + } + } + } + + Some(current_epoch) + } +} + +#[cfg(test)] +mod tests { + use crate::{logger::FileMetricLogger, metric::MetricEntry}; + + use super::*; + + struct TestLogger { 
+ logger: FileMetricLogger, + epoch: usize, + } + const NAME: &str = "test-logger"; + + impl TestLogger { + fn new() -> Self { + Self { + logger: FileMetricLogger::new("/tmp"), + epoch: 1, + } + } + fn log(&mut self, num: f64) { + self.logger.log(&MetricEntry::new( + NAME.into(), + num.to_string(), + num.to_string(), + )); + } + fn new_epoch(&mut self) { + self.epoch += 1; + self.logger.epoch(self.epoch); + } + } + + #[test] + fn should_find_epoch() { + let mut logger = TestLogger::new(); + let mut aggregate = NumericMetricsAggregate::default(); + + logger.log(500.); // Epoch 1 + logger.log(1000.); // Epoch 1 + logger.new_epoch(); + logger.log(200.); // Epoch 2 + logger.log(1000.); // Epoch 2 + logger.new_epoch(); + logger.log(10000.); // Epoch 3 + + let value = aggregate + .find_epoch( + NAME, + Aggregate::Mean, + Direction::Lowest, + &mut [Box::new(logger.logger)], + ) + .unwrap(); + + assert_eq!(value, 2); + } +} diff --git a/burn-train/src/info/metrics.rs b/burn-train/src/info/metrics.rs new file mode 100644 index 0000000000..6a94b00636 --- /dev/null +++ b/burn-train/src/info/metrics.rs @@ -0,0 +1,253 @@ +use super::NumericMetricsAggregate; +use crate::{ + logger::MetricLogger, + metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, + Aggregate, Direction, LearnerItem, Split, +}; + +/// Metrics information collected during training. 
+pub struct MetricsInfo +where + T: Send + Sync + 'static, + V: Send + Sync + 'static, +{ + train: Vec>>, + valid: Vec>>, + train_numeric: Vec>>, + valid_numeric: Vec>>, + loggers_train: Vec>, + loggers_valid: Vec>, + aggregate_train: NumericMetricsAggregate, + aggregate_valid: NumericMetricsAggregate, +} + +#[derive(new)] +pub(crate) struct MetricsUpdate { + pub(crate) entries: Vec, + pub(crate) entries_numeric: Vec<(MetricEntry, f64)>, +} + +impl MetricsInfo +where + T: Send + Sync + 'static, + V: Send + Sync + 'static, +{ + pub(crate) fn new() -> Self { + Self { + train: vec![], + valid: vec![], + train_numeric: vec![], + valid_numeric: vec![], + loggers_train: vec![], + loggers_valid: vec![], + aggregate_train: NumericMetricsAggregate::default(), + aggregate_valid: NumericMetricsAggregate::default(), + } + } + + /// Signal the end of a training epoch. + pub(crate) fn end_epoch_train(&mut self, epoch: usize) { + for metric in self.train.iter_mut() { + metric.clear(); + } + for metric in self.train_numeric.iter_mut() { + metric.clear(); + } + for logger in self.loggers_train.iter_mut() { + logger.epoch(epoch + 1); + } + } + + /// Signal the end of a validation epoch. + pub(crate) fn end_epoch_valid(&mut self, epoch: usize) { + for metric in self.valid.iter_mut() { + metric.clear(); + } + for metric in self.valid_numeric.iter_mut() { + metric.clear(); + } + for logger in self.loggers_valid.iter_mut() { + logger.epoch(epoch + 1); + } + } + + /// Update the training information from the training item. 
+ pub(crate) fn update_train( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.train.len()); + let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); + + for metric in self.train.iter_mut() { + let state = metric.update(item, metadata); + + for logger in self.loggers_train.iter_mut() { + logger.log(&state); + } + + entries.push(state); + } + + for metric in self.train_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + for logger in self.loggers_train.iter_mut() { + logger.log(&state); + } + + entries_numeric.push((state, value)); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Update the training information from the validation item. + pub(crate) fn update_valid( + &mut self, + item: &LearnerItem, + metadata: &MetricMetadata, + ) -> MetricsUpdate { + let mut entries = Vec::with_capacity(self.valid.len()); + let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); + + for metric in self.valid.iter_mut() { + let state = metric.update(item, metadata); + + for logger in self.loggers_valid.iter_mut() { + logger.log(&state); + } + + entries.push(state); + } + + for metric in self.valid_numeric.iter_mut() { + let (state, value) = metric.update(item, metadata); + for logger in self.loggers_valid.iter_mut() { + logger.log(&state); + } + + entries_numeric.push((state, value)); + } + + MetricsUpdate::new(entries, entries_numeric) + } + + /// Find the epoch corresponding to the given criteria. + pub(crate) fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + match split { + Split::Train => { + self.aggregate_train + .find_epoch(name, aggregate, direction, &mut self.loggers_train) + } + Split::Valid => { + self.aggregate_valid + .find_epoch(name, aggregate, direction, &mut self.loggers_valid) + } + } + } + + /// Register a logger for training metrics. 
+ pub(crate) fn register_logger_train(&mut self, logger: ML) { + self.loggers_train.push(Box::new(logger)); + } + + /// Register a logger for validation metrics. + pub(crate) fn register_logger_valid(&mut self, logger: ML) { + self.loggers_valid.push(Box::new(logger)); + } + + /// Register a training metric. + pub(crate) fn register_metric_train(&mut self, metric: Me) + where + T: Adaptor, + { + let metric = MetricWrapper::new(metric); + self.train.push(Box::new(metric)) + } + + /// Register a validation metric. + pub(crate) fn register_valid_metric(&mut self, metric: Me) + where + V: Adaptor, + { + let metric = MetricWrapper::new(metric); + self.valid.push(Box::new(metric)) + } + + /// Register a numeric training metric. + pub(crate) fn register_train_metric_numeric( + &mut self, + metric: Me, + ) where + T: Adaptor, + { + let metric = MetricWrapper::new(metric); + self.train_numeric.push(Box::new(metric)) + } + + /// Register a numeric validation metric. + pub(crate) fn register_valid_metric_numeric( + &mut self, + metric: Me, + ) where + V: Adaptor, + { + let metric = MetricWrapper::new(metric); + self.valid_numeric.push(Box::new(metric)) + } +} + +trait NumericMetricUpdater: Send + Sync { + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); + fn clear(&mut self); +} + +trait MetricUpdater: Send + Sync { + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; + fn clear(&mut self); +} + +#[derive(new)] +struct MetricWrapper { + metric: M, +} + +impl NumericMetricUpdater for MetricWrapper +where + T: 'static, + M: Metric + Numeric + 'static, + T: Adaptor, +{ + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { + let update = self.metric.update(&item.item.adapt(), metadata); + let numeric = self.metric.value(); + + (update, numeric) + } + + fn clear(&mut self) { + self.metric.clear() + } +} + +impl MetricUpdater for MetricWrapper +where + T: 'static, + M: 
Metric + 'static, + T: Adaptor, +{ + fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { + self.metric.update(&item.item.adapt(), metadata) + } + + fn clear(&mut self) { + self.metric.clear() + } +} diff --git a/burn-train/src/info/mod.rs b/burn-train/src/info/mod.rs new file mode 100644 index 0000000000..0adbc16d10 --- /dev/null +++ b/burn-train/src/info/mod.rs @@ -0,0 +1,5 @@ +mod aggregates; +mod metrics; + +pub(crate) use aggregates::*; +pub use metrics::*; diff --git a/burn-train/src/learner/base.rs b/burn-train/src/learner/base.rs index a8ded298d9..2672e755ed 100644 --- a/burn-train/src/learner/base.rs +++ b/burn-train/src/learner/base.rs @@ -19,7 +19,7 @@ pub struct Learner { pub(crate) grad_accumulation: Option, pub(crate) checkpointer: Option>, pub(crate) devices: Vec<::Device>, - pub(crate) callback: LC::Callback, + pub(crate) collector: LC::EventCollector, pub(crate) interrupter: TrainingInterrupter, } diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index 723d821ecf..da19502b52 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -1,14 +1,14 @@ use super::log::install_file_logger; use super::Learner; use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer}; +use crate::collector::metrics::RenderedMetricsEventCollector; use crate::components::LearnerComponentsMarker; +use crate::info::MetricsInfo; use crate::learner::base::TrainingInterrupter; use crate::logger::{FileMetricLogger, MetricLogger}; -use crate::metric::callback::{ - default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer, -}; use crate::metric::{Adaptor, Metric}; -use crate::{AsyncTrainerCallback, LearnerCheckpointer}; +use crate::renderer::{default_renderer, MetricsRenderer}; +use crate::{AsyncEventCollector, LearnerCheckpointer}; use burn_core::lr_scheduler::LrScheduler; use burn_core::module::ADModule; use burn_core::optim::Optimizer; @@ -39,12 +39,11 @@ where 
directory: String, grad_accumulation: Option, devices: Vec, - metric_logger_train: Option>, - metric_logger_valid: Option>, renderer: Option>, - metrics: Metrics, + info: MetricsInfo, interrupter: TrainingInterrupter, log_to_file: bool, + num_loggers: usize, } impl LearnerBuilder @@ -69,12 +68,11 @@ where directory: directory.to_string(), grad_accumulation: None, devices: vec![B::Device::default()], - metric_logger_train: None, - metric_logger_valid: None, - metrics: Metrics::new(), + info: MetricsInfo::new(), renderer: None, interrupter: TrainingInterrupter::new(), log_to_file: true, + num_loggers: 0, } } @@ -89,8 +87,9 @@ where MT: MetricLogger + 'static, MV: MetricLogger + 'static, { - self.metric_logger_train = Some(Box::new(logger_train)); - self.metric_logger_valid = Some(Box::new(logger_valid)); + self.info.register_logger_train(logger_train); + self.info.register_logger_valid(logger_valid); + self.num_loggers += 1; self } @@ -112,9 +111,7 @@ where where T: Adaptor, { - self.metrics - .train - .push(Box::new(MetricWrapper::new(metric))); + self.info.register_metric_train(metric); self } @@ -123,9 +120,7 @@ where where V: Adaptor, { - self.metrics - .valid - .push(Box::new(MetricWrapper::new(metric))); + self.info.register_valid_metric(metric); self } @@ -144,41 +139,25 @@ where self } - /// Register a training metric and displays it on a plot. - /// - /// # Notes - /// - /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot. - /// If the same metric is also registered for the [validation split](Self::metric_valid_plot), - /// the same graph will be used for both. - pub fn metric_train_plot(mut self, metric: Me) -> Self + /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). 
+ pub fn metric_train_numeric(mut self, metric: Me) -> Self where Me: Metric + crate::metric::Numeric + 'static, T: Adaptor, { - self.metrics - .train_numeric - .push(Box::new(MetricWrapper::new(metric))); + self.info.register_train_metric_numeric(metric); self } - /// Register a validation metric and displays it on a plot. - /// - /// # Notes - /// - /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot. - /// If the same metric is also registered for the [training split](Self::metric_train_plot), - /// the same graph will be used for both. - pub fn metric_valid_plot( + /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). + pub fn metric_valid_numeric( mut self, metric: Me, ) -> Self where V: Adaptor, { - self.metrics - .valid_numeric - .push(Box::new(MetricWrapper::new(metric))); + self.info.register_valid_metric_numeric(metric); self } @@ -260,7 +239,7 @@ where #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and // creates a clean learner. 
pub fn build( - self, + mut self, model: M, optim: O, lr_scheduler: S, @@ -273,7 +252,7 @@ where AsyncCheckpointer, AsyncCheckpointer, AsyncCheckpointer, - AsyncTrainerCallback, + AsyncEventCollector, >, > where @@ -288,18 +267,18 @@ where Box::new(default_renderer(self.interrupter.clone(), self.checkpoint)) }); let directory = &self.directory; - let logger_train = self.metric_logger_train.unwrap_or_else(|| { - Box::new(FileMetricLogger::new(format!("{directory}/train").as_str())) - }); - let logger_valid = self.metric_logger_valid.unwrap_or_else(|| { - Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str())) - }); - let callback = AsyncTrainerCallback::new(MetricsCallback::new( - renderer, - self.metrics, - logger_train, - logger_valid, - )); + + if self.num_loggers == 0 { + self.info.register_logger_train(FileMetricLogger::new( + format!("{directory}/train").as_str(), + )); + self.info.register_logger_valid(FileMetricLogger::new( + format!("{directory}/valid").as_str(), + )); + } + + let collector = + AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info)); let checkpointer = self .checkpointers @@ -311,7 +290,7 @@ where lr_scheduler, checkpointer, num_epochs: self.num_epochs, - callback, + collector, checkpoint: self.checkpoint, grad_accumulation: self.grad_accumulation, devices: self.devices, diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index b9d109bbf2..390e1024bb 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -7,8 +7,8 @@ use burn_core::{ }; use std::sync::Arc; -use crate::learner::base::TrainingInterrupter; -use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep}; +use crate::{learner::base::TrainingInterrupter, Event}; +use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep}; /// A validation epoch. 
#[derive(new)] @@ -37,7 +37,7 @@ impl ValidEpoch { pub fn run( &self, model: &M, - callback: &mut Box>, + callback: &mut Box>, interrupter: &TrainingInterrupter, ) where B: ADBackend, @@ -64,13 +64,14 @@ impl ValidEpoch { None, ); - callback.on_valid_item(item); + callback.on_event_valid(Event::ProcessedItem(item)); + if interrupter.should_stop() { log::info!("Training interrupted."); break; } } - callback.on_valid_end_epoch(self.epoch); + callback.on_event_valid(Event::EndEpoch(self.epoch)); } } @@ -92,7 +93,7 @@ impl TrainEpoch { mut model: M, mut optim: O, scheduler: &mut LR, - callback: &mut Box>, + callback: &mut Box>, interrupter: &TrainingInterrupter, ) -> (M, O) where @@ -139,13 +140,13 @@ impl TrainEpoch { Some(lr), ); - callback.on_train_item(item); + callback.on_event_train(Event::ProcessedItem(item)); if interrupter.should_stop() { log::info!("Training interrupted."); break; } } - callback.on_train_end_epoch(self.epoch); + callback.on_event_train(Event::EndEpoch(self.epoch)); (model, optim) } @@ -170,7 +171,7 @@ impl TrainEpoch { mut model: M, mut optim: O, lr_scheduler: &mut S, - callback: &mut Box>, + callback: &mut Box>, devices: Vec, interrupter: &TrainingInterrupter, ) -> (M, O) @@ -232,7 +233,8 @@ impl TrainEpoch { Some(lr), ); - callback.on_train_item(item); + callback.on_event_train(Event::ProcessedItem(item)); + if interrupter.should_stop() { log::info!("Training interrupted."); interrupted = true; @@ -245,7 +247,7 @@ impl TrainEpoch { } } - callback.on_train_end_epoch(self.epoch); + callback.on_event_train(Event::EndEpoch(self.epoch)); (model, optim) } diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index 09fa6487b0..27fe1284f8 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -1,5 +1,5 @@ use crate::components::LearnerComponents; -use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch}; +use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch}; use 
burn_core::data::dataloader::DataLoader; use burn_core::module::{ADModule, Module}; use burn_core::optim::{GradientsParams, Optimizer}; @@ -115,7 +115,7 @@ impl Learner { OutputValid: Send, LC::Model: TrainStep, >::InnerModule: ValidStep, - LC::Callback: LearnerCallback, + LC::EventCollector: EventCollector, { log::info!("Fitting {}", self.model.to_string()); // The reference model is always on the first device provided. @@ -139,8 +139,8 @@ impl Learner { }; let mut callback: Box< - dyn LearnerCallback, - > = Box::new(self.callback); + dyn EventCollector, + > = Box::new(self.collector); for epoch in starting_epoch..self.num_epochs + 1 { let epoch_train = TrainEpoch::new( diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index 6ea64ed29b..23d1a5e817 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -10,16 +10,22 @@ pub mod checkpoint; pub(crate) mod components; +/// Renderer modules to display metrics and training information. +pub mod renderer; + /// The logger module. pub mod logger; /// The metric module. pub mod metric; -mod callback; +/// All information collected during training. +pub mod info; + +mod collector; mod learner; -pub use callback::*; +pub use collector::*; pub use learner::*; #[cfg(test)] diff --git a/burn-train/src/logger/async_logger.rs b/burn-train/src/logger/async_logger.rs index a9b8e1c654..f771f57639 100644 --- a/burn-train/src/logger/async_logger.rs +++ b/burn-train/src/logger/async_logger.rs @@ -4,6 +4,7 @@ use std::sync::mpsc; enum Message { Log(T), End, + Sync(mpsc::Sender<()>), } /// Async logger. pub struct AsyncLogger { @@ -30,6 +31,9 @@ where Message::End => { return; } + Message::Sync(callback) => { + callback.send(()).unwrap(); + } } } } @@ -48,6 +52,17 @@ impl AsyncLogger { Self { sender, handler } } + + /// Sync the async logger. 
+ pub(crate) fn sync(&self) { + let (sender, receiver) = mpsc::channel(); + + self.sender.send(Message::Sync(sender)).unwrap(); + + receiver + .recv() + .expect("Should sync, otherwise the thread is dead."); + } } impl Logger for AsyncLogger { diff --git a/burn-train/src/logger/metric.rs b/burn-train/src/logger/metric.rs index 4c3f2128a8..127ab022b6 100644 --- a/burn-train/src/logger/metric.rs +++ b/burn-train/src/logger/metric.rs @@ -49,11 +49,14 @@ impl FileMetricLogger { fn file_path(&self, name: &str, epoch: usize) -> String { let directory = format!("{}/epoch-{}", self.directory, epoch); - std::fs::create_dir_all(&directory).ok(); let name = name.replace(' ', "_"); format!("{directory}/{name}.log") } + fn create_directory(&self, epoch: usize) { + let directory = format!("{}/epoch-{}", self.directory, epoch); + std::fs::create_dir_all(directory).ok(); + } } impl MetricLogger for FileMetricLogger { @@ -64,6 +67,8 @@ impl MetricLogger for FileMetricLogger { let logger = match self.loggers.get_mut(key) { Some(val) => val, None => { + self.create_directory(self.epoch); + let file_path = self.file_path(key, self.epoch); let logger = FileLogger::new(&file_path); let logger = AsyncLogger::new(logger); @@ -82,6 +87,10 @@ impl MetricLogger for FileMetricLogger { } fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { + if let Some(value) = self.loggers.get(name) { + value.sync() + } + let file_path = self.file_path(name, epoch); let mut errors = false; diff --git a/burn-train/src/metric/callback/base.rs b/burn-train/src/metric/callback/base.rs deleted file mode 100644 index 91b20740a5..0000000000 --- a/burn-train/src/metric/callback/base.rs +++ /dev/null @@ -1,285 +0,0 @@ -use crate::{ - logger::MetricLogger, - metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, - LearnerCallback, LearnerItem, -}; -use burn_core::data::dataloader::Progress; - -/// Holds all metrics, metric loggers, and a metrics renderer. 
-pub struct MetricsCallback -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - metrics: Metrics, - logger_train: Box, - logger_valid: Box, - renderer: Box, -} - -impl MetricsCallback -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - /// Creates a new metrics callback. - /// - /// # Arguments - /// - /// * `renderer` - The metrics renderer. - /// * `metrics` - The metrics holder. - /// * `logger_train` - The training logger. - /// * `logger_valid` - The validation logger. - /// - /// # Returns - /// - /// A new metrics callback. - pub(crate) fn new( - renderer: Box, - metrics: Metrics, - logger_train: Box, - logger_valid: Box, - ) -> Self { - Self { - metrics, - logger_train, - logger_valid, - renderer, - } - } -} - -impl LearnerCallback for MetricsCallback -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - type ItemTrain = T; - type ItemValid = V; - - fn on_train_item(&mut self, item: LearnerItem) { - let metadata = (&item).into(); - for metric in self.metrics.train.iter_mut() { - let state = metric.update(&item, &metadata); - self.logger_train.log(&state); - - self.renderer.update_train(MetricState::Generic(state)); - } - for metric in self.metrics.train_numeric.iter_mut() { - let (state, value) = metric.update(&item, &metadata); - self.logger_train.log(&state); - - self.renderer - .update_train(MetricState::Numeric(state, value)); - } - self.renderer.render_train(item.into()); - } - - fn on_valid_item(&mut self, item: LearnerItem) { - let metadata = (&item).into(); - for metric in self.metrics.valid.iter_mut() { - let state = metric.update(&item, &metadata); - self.logger_valid.log(&state); - - self.renderer.update_valid(MetricState::Generic(state)); - } - for metric in self.metrics.valid_numeric.iter_mut() { - let (state, value) = metric.update(&item, &metadata); - self.logger_valid.log(&state); - - self.renderer - .update_valid(MetricState::Numeric(state, value)); - } - self.renderer.render_valid(item.into()); 
- } - - fn on_train_end_epoch(&mut self, epoch: usize) { - for metric in self.metrics.train.iter_mut() { - metric.clear(); - } - for metric in self.metrics.train_numeric.iter_mut() { - metric.clear(); - } - self.logger_train.epoch(epoch + 1); - } - - fn on_valid_end_epoch(&mut self, epoch: usize) { - for metric in self.metrics.valid.iter_mut() { - metric.clear(); - } - for metric in self.metrics.valid_numeric.iter_mut() { - metric.clear(); - } - self.logger_valid.epoch(epoch + 1); - } -} - -/// Training progress. -#[derive(Debug)] -pub struct TrainingProgress { - /// The progress. - pub progress: Progress, - - /// The epoch. - pub epoch: usize, - - /// The total number of epochs. - pub epoch_total: usize, - - /// The iteration. - pub iteration: usize, -} - -impl TrainingProgress { - /// Creates a new empty training progress. - pub fn none() -> Self { - Self { - progress: Progress { - items_processed: 0, - items_total: 0, - }, - epoch: 0, - epoch_total: 0, - iteration: 0, - } - } -} - -/// The state of a metric. -#[derive(Debug)] -pub enum MetricState { - /// A generic metric. - Generic(MetricEntry), - - /// A numeric metric. - Numeric(MetricEntry, f64), -} - -/// Trait for rendering metrics. -pub trait MetricsRenderer: Send + Sync { - /// Updates the training metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_train(&mut self, state: MetricState); - - /// Updates the validation metric state. - /// - /// # Arguments - /// - /// * `state` - The metric state. - fn update_valid(&mut self, state: MetricState); - - /// Renders the training progress. - /// - /// # Arguments - /// - /// * `item` - The training progress. - fn render_train(&mut self, item: TrainingProgress); - - /// Renders the validation progress. - /// - /// # Arguments - /// - /// * `item` - The validation progress. - fn render_valid(&mut self, item: TrainingProgress); -} - -/// A container for the metrics held by a metrics callback. 
-pub(crate) struct Metrics -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - pub(crate) train: Vec>>, - pub(crate) valid: Vec>>, - pub(crate) train_numeric: Vec>>, - pub(crate) valid_numeric: Vec>>, -} - -impl Metrics -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - pub fn new() -> Self { - Self { - train: vec![], - valid: vec![], - train_numeric: vec![], - valid_numeric: vec![], - } - } -} - -impl From> for TrainingProgress { - fn from(item: LearnerItem) -> Self { - Self { - progress: item.progress, - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - } - } -} - -impl From<&LearnerItem> for MetricMetadata { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - lr: item.lr, - } - } -} - -pub(crate) trait NumericMetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64); - fn clear(&mut self); -} - -pub(crate) trait MetricUpdater: Send + Sync { - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry; - fn clear(&mut self); -} - -#[derive(new)] -pub(crate) struct MetricWrapper { - metric: M, -} - -impl NumericMetricUpdater for MetricWrapper -where - T: 'static, - M: Metric + Numeric + 'static, - T: Adaptor, -{ - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> (MetricEntry, f64) { - let update = self.metric.update(&item.item.adapt(), metadata); - let numeric = self.metric.value(); - - (update, numeric) - } - - fn clear(&mut self) { - self.metric.clear() - } -} - -impl MetricUpdater for MetricWrapper -where - T: 'static, - M: Metric + 'static, - T: Adaptor, -{ - fn update(&mut self, item: &LearnerItem, metadata: &MetricMetadata) -> MetricEntry { - self.metric.update(&item.item.adapt(), metadata) - } - - fn clear(&mut self) { - self.metric.clear() - } -} diff --git 
a/burn-train/src/metric/mod.rs b/burn-train/src/metric/mod.rs index c077e40c1b..75ecdc132c 100644 --- a/burn-train/src/metric/mod.rs +++ b/burn-train/src/metric/mod.rs @@ -1,7 +1,4 @@ -/// Callback module for training progress. -pub mod callback; - -/// State module for callback metrics. +/// State module. pub mod state; mod acc; diff --git a/burn-train/src/metric/state.rs b/burn-train/src/metric/state.rs index 37ef077b8c..567249275e 100644 --- a/burn-train/src/metric/state.rs +++ b/burn-train/src/metric/state.rs @@ -1,4 +1,4 @@ -use super::{format_float, MetricEntry, Numeric}; +use crate::metric::{format_float, MetricEntry, Numeric}; /// Usefull utility to implement numeric metrics. /// diff --git a/burn-train/src/renderer/base.rs b/burn-train/src/renderer/base.rs new file mode 100644 index 0000000000..6cfc2a5eb0 --- /dev/null +++ b/burn-train/src/renderer/base.rs @@ -0,0 +1,75 @@ +use burn_core::data::dataloader::Progress; + +use crate::metric::MetricEntry; + +/// Trait for rendering metrics. +pub trait MetricsRenderer: Send + Sync { + /// Updates the training metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_train(&mut self, state: MetricState); + + /// Updates the validation metric state. + /// + /// # Arguments + /// + /// * `state` - The metric state. + fn update_valid(&mut self, state: MetricState); + + /// Renders the training progress. + /// + /// # Arguments + /// + /// * `item` - The training progress. + fn render_train(&mut self, item: TrainingProgress); + + /// Renders the validation progress. + /// + /// # Arguments + /// + /// * `item` - The validation progress. + fn render_valid(&mut self, item: TrainingProgress); +} + +/// The state of a metric. +#[derive(Debug)] +pub enum MetricState { + /// A generic metric. + Generic(MetricEntry), + + /// A numeric metric. + Numeric(MetricEntry, f64), +} + +/// Training progress. +#[derive(Debug)] +pub struct TrainingProgress { + /// The progress. 
+ pub progress: Progress, + + /// The epoch. + pub epoch: usize, + + /// The total number of epochs. + pub epoch_total: usize, + + /// The iteration. + pub iteration: usize, +} + +impl TrainingProgress { + /// Creates a new empty training progress. + pub fn none() -> Self { + Self { + progress: Progress { + items_processed: 0, + items_total: 0, + }, + epoch: 0, + epoch_total: 0, + iteration: 0, + } + } +} diff --git a/burn-train/src/metric/callback/cli.rs b/burn-train/src/renderer/cli.rs similarity index 100% rename from burn-train/src/metric/callback/cli.rs rename to burn-train/src/renderer/cli.rs diff --git a/burn-train/src/metric/callback/mod.rs b/burn-train/src/renderer/mod.rs similarity index 99% rename from burn-train/src/metric/callback/mod.rs rename to burn-train/src/renderer/mod.rs index 9ff7f99d7e..9002184326 100644 --- a/burn-train/src/metric/callback/mod.rs +++ b/burn-train/src/renderer/mod.rs @@ -1,5 +1,4 @@ mod base; - pub use base::*; #[cfg(not(feature = "tui"))] diff --git a/burn-train/src/metric/callback/tui/base.rs b/burn-train/src/renderer/tui/base.rs similarity index 100% rename from burn-train/src/metric/callback/tui/base.rs rename to burn-train/src/renderer/tui/base.rs diff --git a/burn-train/src/metric/callback/tui/controls.rs b/burn-train/src/renderer/tui/controls.rs similarity index 100% rename from burn-train/src/metric/callback/tui/controls.rs rename to burn-train/src/renderer/tui/controls.rs diff --git a/burn-train/src/metric/callback/tui/full_history.rs b/burn-train/src/renderer/tui/full_history.rs similarity index 100% rename from burn-train/src/metric/callback/tui/full_history.rs rename to burn-train/src/renderer/tui/full_history.rs diff --git a/burn-train/src/metric/callback/tui/metric_numeric.rs b/burn-train/src/renderer/tui/metric_numeric.rs similarity index 99% rename from burn-train/src/metric/callback/tui/metric_numeric.rs rename to burn-train/src/renderer/tui/metric_numeric.rs index 3e8200067c..ccae8e295c 100644 --- 
a/burn-train/src/metric/callback/tui/metric_numeric.rs +++ b/burn-train/src/renderer/tui/metric_numeric.rs @@ -1,4 +1,4 @@ -use crate::metric::callback::TrainingProgress; +use crate::renderer::TrainingProgress; use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame}; use crossterm::event::{Event, KeyCode}; diff --git a/burn-train/src/metric/callback/tui/metric_text.rs b/burn-train/src/renderer/tui/metric_text.rs similarity index 100% rename from burn-train/src/metric/callback/tui/metric_text.rs rename to burn-train/src/renderer/tui/metric_text.rs diff --git a/burn-train/src/metric/callback/tui/mod.rs b/burn-train/src/renderer/tui/mod.rs similarity index 100% rename from burn-train/src/metric/callback/tui/mod.rs rename to burn-train/src/renderer/tui/mod.rs diff --git a/burn-train/src/metric/callback/tui/plot_utils.rs b/burn-train/src/renderer/tui/plot_utils.rs similarity index 100% rename from burn-train/src/metric/callback/tui/plot_utils.rs rename to burn-train/src/renderer/tui/plot_utils.rs diff --git a/burn-train/src/metric/callback/tui/popup.rs b/burn-train/src/renderer/tui/popup.rs similarity index 100% rename from burn-train/src/metric/callback/tui/popup.rs rename to burn-train/src/renderer/tui/popup.rs diff --git a/burn-train/src/metric/callback/tui/progress.rs b/burn-train/src/renderer/tui/progress.rs similarity index 99% rename from burn-train/src/metric/callback/tui/progress.rs rename to burn-train/src/renderer/tui/progress.rs index b8529f55f8..b41cc4b2dc 100644 --- a/burn-train/src/metric/callback/tui/progress.rs +++ b/burn-train/src/renderer/tui/progress.rs @@ -1,5 +1,6 @@ +use crate::renderer::TrainingProgress; + use super::TerminalFrame; -use crate::metric::callback::TrainingProgress; use ratatui::{ prelude::{Alignment, Constraint, Direction, Layout, Rect}, style::{Color, Style, Stylize}, diff --git a/burn-train/src/metric/callback/tui/recent_history.rs b/burn-train/src/renderer/tui/recent_history.rs similarity index 100% rename from 
burn-train/src/metric/callback/tui/recent_history.rs rename to burn-train/src/renderer/tui/recent_history.rs diff --git a/burn-train/src/metric/callback/tui/renderer.rs b/burn-train/src/renderer/tui/renderer.rs similarity index 97% rename from burn-train/src/metric/callback/tui/renderer.rs rename to burn-train/src/renderer/tui/renderer.rs index b00c0fb85f..015e3e88ca 100644 --- a/burn-train/src/metric/callback/tui/renderer.rs +++ b/burn-train/src/renderer/tui/renderer.rs @@ -1,5 +1,5 @@ -use crate::metric::callback::tui::NumericMetricsState; -use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress}; +use crate::renderer::{tui::NumericMetricsState, MetricsRenderer}; +use crate::renderer::{MetricState, TrainingProgress}; use crate::TrainingInterrupter; use crossterm::{ event::{self, Event, KeyCode}, diff --git a/burn-train/src/metric/callback/tui/status.rs b/burn-train/src/renderer/tui/status.rs similarity index 98% rename from burn-train/src/metric/callback/tui/status.rs rename to burn-train/src/renderer/tui/status.rs index e03be81833..3519d217cf 100644 --- a/burn-train/src/metric/callback/tui/status.rs +++ b/burn-train/src/renderer/tui/status.rs @@ -1,5 +1,5 @@ use super::TerminalFrame; -use crate::metric::callback::TrainingProgress; +use crate::renderer::TrainingProgress; use ratatui::{ prelude::{Alignment, Rect}, style::{Color, Style, Stylize}, diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs index bef8cee398..3ce942b34f 100644 --- a/examples/custom-renderer/src/lib.rs +++ b/examples/custom-renderer/src/lib.rs @@ -1,5 +1,5 @@ use burn::data::dataset::source::huggingface::MNISTDataset; -use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress}; +use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress}; use burn::train::LearnerBuilder; use burn::{ config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig, diff --git a/examples/guide/src/training.rs 
b/examples/guide/src/training.rs index 4eb3741ee6..aca4a66d97 100644 --- a/examples/guide/src/training.rs +++ b/examples/guide/src/training.rs @@ -88,10 +88,10 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, device: B .build(MNISTDataset::test()); let learner = LearnerBuilder::new(artifact_dir) - .metric_train_plot(AccuracyMetric::new()) - .metric_valid_plot(AccuracyMetric::new()) - .metric_train_plot(LossMetric::new()) - .metric_valid_plot(LossMetric::new()) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(1, CompactRecorder::new()) .devices(vec![device]) .num_epochs(config.num_epochs) diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index 6d4b4e83a1..2e4504366a 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -58,16 +58,16 @@ pub fn run(device: B::Device) { // Model let learner = LearnerBuilder::new(ARTIFACT_DIR) - .metric_train_plot(AccuracyMetric::new()) - .metric_valid_plot(AccuracyMetric::new()) - .metric_train_plot(CpuUse::new()) - .metric_valid_plot(CpuUse::new()) - .metric_train_plot(CpuMemory::new()) - .metric_valid_plot(CpuMemory::new()) - .metric_train_plot(CpuTemperature::new()) - .metric_valid_plot(CpuTemperature::new()) - .metric_train_plot(LossMetric::new()) - .metric_valid_plot(LossMetric::new()) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(CpuUse::new()) + .metric_valid_numeric(CpuUse::new()) + .metric_train_numeric(CpuMemory::new()) + .metric_valid_numeric(CpuMemory::new()) + .metric_train_numeric(CpuTemperature::new()) + .metric_valid_numeric(CpuTemperature::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(1, CompactRecorder::new()) .devices(vec![device]) 
.num_epochs(config.num_epochs) diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index a128426a26..99eb337593 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -95,9 +95,9 @@ pub fn train( .metric_valid(CUDAMetric::new()) .metric_train(AccuracyMetric::new()) .metric_valid(AccuracyMetric::new()) - .metric_train_plot(LossMetric::new()) - .metric_valid_plot(LossMetric::new()) - .metric_train_plot(LearningRateMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(2, CompactRecorder::new()) .devices(vec![device]) .num_epochs(config.num_epochs) diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs index b4442d067f..7b92662e48 100644 --- a/examples/text-generation/src/training.rs +++ b/examples/text-generation/src/training.rs @@ -70,11 +70,11 @@ pub fn train + 'static>( let learner = LearnerBuilder::new(artifact_dir) .metric_train(CUDAMetric::new()) .metric_valid(CUDAMetric::new()) - .metric_train_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) - .metric_valid_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) + .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) + .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) .metric_train(LossMetric::new()) .metric_valid(LossMetric::new()) - .metric_train_plot(LearningRateMetric::new()) + .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(2, CompactRecorder::new()) .devices(vec![device]) .grads_accumulation(accum)