From 620b86de9892a43ff927718cc89abf0a34e91976 Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Tue, 10 Oct 2023 13:27:03 -0400
Subject: [PATCH] Feat training events (#857)
---
burn-book/src/basic-workflow/training.md | 8 +-
burn-train/src/callback/async_callback.rs | 97 ------
burn-train/src/callback/base.rs | 43 ---
burn-train/src/callback/mod.rs | 5 -
burn-train/src/collector/async_collector.rs | 118 ++++++++
burn-train/src/collector/base.rs | 78 +++++
burn-train/src/collector/metrics/base.rs | 131 ++++++++
burn-train/src/collector/metrics/mod.rs | 3 +
burn-train/src/collector/mod.rs | 8 +
burn-train/src/components.rs | 14 +-
burn-train/src/info/aggregates.rs | 163 ++++++++++
burn-train/src/info/metrics.rs | 253 ++++++++++++++++
burn-train/src/info/mod.rs | 5 +
burn-train/src/learner/base.rs | 2 +-
burn-train/src/learner/builder.rs | 89 +++---
burn-train/src/learner/epoch.rs | 24 +-
burn-train/src/learner/train_val.rs | 8 +-
burn-train/src/lib.rs | 10 +-
burn-train/src/logger/async_logger.rs | 15 +
burn-train/src/logger/metric.rs | 11 +-
burn-train/src/metric/callback/base.rs | 285 ------------------
burn-train/src/metric/mod.rs | 5 +-
burn-train/src/metric/state.rs | 2 +-
burn-train/src/renderer/base.rs | 75 +++++
.../src/{metric/callback => renderer}/cli.rs | 0
.../src/{metric/callback => renderer}/mod.rs | 1 -
.../{metric/callback => renderer}/tui/base.rs | 0
.../callback => renderer}/tui/controls.rs | 0
.../callback => renderer}/tui/full_history.rs | 0
.../tui/metric_numeric.rs | 2 +-
.../callback => renderer}/tui/metric_text.rs | 0
.../{metric/callback => renderer}/tui/mod.rs | 0
.../callback => renderer}/tui/plot_utils.rs | 0
.../callback => renderer}/tui/popup.rs | 0
.../callback => renderer}/tui/progress.rs | 3 +-
.../tui/recent_history.rs | 0
.../callback => renderer}/tui/renderer.rs | 4 +-
.../callback => renderer}/tui/status.rs | 2 +-
examples/custom-renderer/src/lib.rs | 2 +-
examples/guide/src/training.rs | 8 +-
examples/mnist/src/training.rs | 20 +-
examples/text-classification/src/training.rs | 6 +-
examples/text-generation/src/training.rs | 6 +-
43 files changed, 959 insertions(+), 547 deletions(-)
delete mode 100644 burn-train/src/callback/async_callback.rs
delete mode 100644 burn-train/src/callback/base.rs
delete mode 100644 burn-train/src/callback/mod.rs
create mode 100644 burn-train/src/collector/async_collector.rs
create mode 100644 burn-train/src/collector/base.rs
create mode 100644 burn-train/src/collector/metrics/base.rs
create mode 100644 burn-train/src/collector/metrics/mod.rs
create mode 100644 burn-train/src/collector/mod.rs
create mode 100644 burn-train/src/info/aggregates.rs
create mode 100644 burn-train/src/info/metrics.rs
create mode 100644 burn-train/src/info/mod.rs
delete mode 100644 burn-train/src/metric/callback/base.rs
create mode 100644 burn-train/src/renderer/base.rs
rename burn-train/src/{metric/callback => renderer}/cli.rs (100%)
rename burn-train/src/{metric/callback => renderer}/mod.rs (99%)
rename burn-train/src/{metric/callback => renderer}/tui/base.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/controls.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/full_history.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/metric_numeric.rs (99%)
rename burn-train/src/{metric/callback => renderer}/tui/metric_text.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/mod.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/plot_utils.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/popup.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/progress.rs (99%)
rename burn-train/src/{metric/callback => renderer}/tui/recent_history.rs (100%)
rename burn-train/src/{metric/callback => renderer}/tui/renderer.rs (97%)
rename burn-train/src/{metric/callback => renderer}/tui/status.rs (98%)
diff --git a/burn-book/src/basic-workflow/training.md b/burn-book/src/basic-workflow/training.md
index 7cf2e195f9..286270342b 100644
--- a/burn-book/src/basic-workflow/training.md
+++ b/burn-book/src/basic-workflow/training.md
@@ -111,10 +111,10 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, device: B
.build(MNISTDataset::test());
let learner = LearnerBuilder::new(artifact_dir)
- .metric_train_plot(AccuracyMetric::new())
- .metric_valid_plot(AccuracyMetric::new())
- .metric_train_plot(LossMetric::new())
- .metric_valid_plot(LossMetric::new())
+ .metric_train_numeric(AccuracyMetric::new())
+ .metric_valid_numeric(AccuracyMetric::new())
+ .metric_train_numeric(LossMetric::new())
+ .metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(1, CompactRecorder::new())
.devices(vec![device])
.num_epochs(config.num_epochs)
diff --git a/burn-train/src/callback/async_callback.rs b/burn-train/src/callback/async_callback.rs
deleted file mode 100644
index 2e0fcbc294..0000000000
--- a/burn-train/src/callback/async_callback.rs
+++ /dev/null
@@ -1,97 +0,0 @@
-use super::{LearnerCallback, LearnerItem};
-use std::{sync::mpsc, thread::JoinHandle};
-
-enum Message {
- LogTrain(LearnerItem),
- LogValid(LearnerItem),
- ClearTrain(usize),
- ClearValid(usize),
- End,
-}
-
-/// Async trainer callback tracker.
-pub struct AsyncTrainerCallback {
- sender: mpsc::Sender>,
- handler: Option>,
-}
-
-#[derive(new)]
-struct CallbackThread {
- callback: C,
- receiver: mpsc::Receiver>,
-}
-
-impl CallbackThread
-where
- C: LearnerCallback,
-{
- fn run(mut self) {
- for item in self.receiver.iter() {
- match item {
- Message::LogTrain(item) => {
- self.callback.on_train_item(item);
- }
- Message::ClearTrain(epoch) => {
- self.callback.on_train_end_epoch(epoch);
- }
- Message::LogValid(item) => {
- self.callback.on_valid_item(item);
- }
- Message::ClearValid(epoch) => {
- self.callback.on_valid_end_epoch(epoch);
- }
- Message::End => {
- return;
- }
- }
- }
- }
-}
-
-impl AsyncTrainerCallback {
- /// Create a new async trainer callback.
- pub fn new(callback: C) -> Self
- where
- C: LearnerCallback + 'static,
- {
- let (sender, receiver) = mpsc::channel();
- let thread = CallbackThread::new(callback, receiver);
-
- let handler = std::thread::spawn(move || thread.run());
- let handler = Some(handler);
-
- Self { sender, handler }
- }
-}
-
-impl LearnerCallback for AsyncTrainerCallback {
- type ItemTrain = T;
- type ItemValid = V;
-
- fn on_train_item(&mut self, item: LearnerItem) {
- self.sender.send(Message::LogTrain(item)).unwrap();
- }
-
- fn on_valid_item(&mut self, item: LearnerItem) {
- self.sender.send(Message::LogValid(item)).unwrap();
- }
-
- fn on_train_end_epoch(&mut self, epoch: usize) {
- self.sender.send(Message::ClearTrain(epoch)).unwrap();
- }
-
- fn on_valid_end_epoch(&mut self, epoch: usize) {
- self.sender.send(Message::ClearValid(epoch)).unwrap();
- }
-}
-
-impl Drop for AsyncTrainerCallback {
- fn drop(&mut self) {
- self.sender.send(Message::End).unwrap();
- let handler = self.handler.take();
-
- if let Some(handler) = handler {
- handler.join().unwrap();
- }
- }
-}
diff --git a/burn-train/src/callback/base.rs b/burn-train/src/callback/base.rs
deleted file mode 100644
index 6750522104..0000000000
--- a/burn-train/src/callback/base.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-use burn_core::{data::dataloader::Progress, LearningRate};
-
-/// The base trait for trainer callbacks.
-pub trait LearnerCallback: Send {
- /// Training item.
- type ItemTrain;
- /// Validation item.
- type ItemValid;
-
- /// Called when a training item is logged.
- fn on_train_item(&mut self, _item: LearnerItem) {}
-
- /// Called when a validation item is logged.
- fn on_valid_item(&mut self, _item: LearnerItem) {}
-
- /// Called when a training epoch is finished.
- fn on_train_end_epoch(&mut self, _epoch: usize) {}
-
- /// Called when a validation epoch is finished.
- fn on_valid_end_epoch(&mut self, _epoch: usize) {}
-}
-
-/// A learner item.
-#[derive(new)]
-pub struct LearnerItem {
- /// The item.
- pub item: T,
-
- /// The progress.
- pub progress: Progress,
-
- /// The epoch.
- pub epoch: usize,
-
- /// The total number of epochs.
- pub epoch_total: usize,
-
- /// The iteration.
- pub iteration: usize,
-
- /// The learning rate.
- pub lr: Option,
-}
diff --git a/burn-train/src/callback/mod.rs b/burn-train/src/callback/mod.rs
deleted file mode 100644
index 809b581de4..0000000000
--- a/burn-train/src/callback/mod.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-mod async_callback;
-mod base;
-
-pub use async_callback::*;
-pub use base::*;
diff --git a/burn-train/src/collector/async_collector.rs b/burn-train/src/collector/async_collector.rs
new file mode 100644
index 0000000000..eb4f75d705
--- /dev/null
+++ b/burn-train/src/collector/async_collector.rs
@@ -0,0 +1,118 @@
+use super::EventCollector;
+use crate::{Aggregate, Direction, Event, Split};
+use std::{sync::mpsc, thread::JoinHandle};
+
+enum Message {
+ OnEventTrain(Event),
+ OnEventValid(Event),
+ End,
+ FindEpoch(
+ String,
+ Aggregate,
+ Direction,
+ Split,
+ mpsc::SyncSender