
Feat training events (tracel-ai#857)
nathanielsimard authored Oct 10, 2023
1 parent 097fd95 commit 620b86d
Showing 43 changed files with 959 additions and 547 deletions.
8 changes: 4 additions & 4 deletions burn-book/src/basic-workflow/training.md
@@ -111,10 +111,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
         .build(MNISTDataset::test());

     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_plot(AccuracyMetric::new())
-        .metric_valid_plot(AccuracyMetric::new())
-        .metric_train_plot(LossMetric::new())
-        .metric_valid_plot(LossMetric::new())
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
         .with_file_checkpointer(1, CompactRecorder::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)
97 changes: 0 additions & 97 deletions burn-train/src/callback/async_callback.rs

This file was deleted.

43 changes: 0 additions & 43 deletions burn-train/src/callback/base.rs

This file was deleted.

5 changes: 0 additions & 5 deletions burn-train/src/callback/mod.rs

This file was deleted.

118 changes: 118 additions & 0 deletions burn-train/src/collector/async_collector.rs
@@ -0,0 +1,118 @@
use super::EventCollector;
use crate::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};

enum Message<T, V> {
    OnEventTrain(Event<T>),
    OnEventValid(Event<V>),
    End,
    FindEpoch(
        String,
        Aggregate,
        Direction,
        Split,
        mpsc::SyncSender<Option<usize>>,
    ),
}

/// Async [event collector](EventCollector).
///
/// This will create a worker thread where all the computation is done, ensuring that the
/// training loop is never blocked by metric calculation.
pub struct AsyncEventCollector<T, V> {
    sender: mpsc::Sender<Message<T, V>>,
    handler: Option<JoinHandle<()>>,
}

#[derive(new)]
struct WorkerThread<C, T, V> {
    collector: C,
    receiver: mpsc::Receiver<Message<T, V>>,
}

impl<C, T, V> WorkerThread<C, T, V>
where
    C: EventCollector<ItemTrain = T, ItemValid = V>,
{
    fn run(mut self) {
        for item in self.receiver.iter() {
            match item {
                Message::End => {
                    return;
                }
                Message::FindEpoch(name, aggregate, direction, split, sender) => {
                    let response = self
                        .collector
                        .find_epoch(&name, aggregate, direction, split);
                    sender.send(response).unwrap();
                }
                Message::OnEventTrain(event) => self.collector.on_event_train(event),
                Message::OnEventValid(event) => self.collector.on_event_valid(event),
            }
        }
    }
}

impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
    /// Create a new async [event collector](EventCollector).
    pub fn new<C>(collector: C) -> Self
    where
        C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
    {
        let (sender, receiver) = mpsc::channel();
        let thread = WorkerThread::new(collector, receiver);

        let handler = std::thread::spawn(move || thread.run());
        let handler = Some(handler);

        Self { sender, handler }
    }
}

impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
    type ItemTrain = T;
    type ItemValid = V;

    fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
        self.sender.send(Message::OnEventTrain(event)).unwrap();
    }

    fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
        self.sender.send(Message::OnEventValid(event)).unwrap();
    }

    fn find_epoch(
        &mut self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
    ) -> Option<usize> {
        let (sender, receiver) = mpsc::sync_channel(1);
        self.sender
            .send(Message::FindEpoch(
                name.to_string(),
                aggregate,
                direction,
                split,
                sender,
            ))
            .unwrap();

        match receiver.recv() {
            Ok(value) => value,
            Err(err) => panic!("Async server crashed: {:?}", err),
        }
    }
}

impl<T, V> Drop for AsyncEventCollector<T, V> {
    fn drop(&mut self) {
        self.sender.send(Message::End).unwrap();
        let handler = self.handler.take();

        if let Some(handler) = handler {
            handler.join().unwrap();
        }
    }
}
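
The file above follows a standard worker-thread pattern: the wrapped collector lives on a background thread that drains an mpsc channel, fire-and-forget events are only enqueued, and FindEpoch carries its own SyncSender so the caller can block for a reply. Below is a minimal, self-contained sketch of that same request/response pattern using only std; the Request type and its variants are illustrative names, not part of this commit.

use std::sync::mpsc;
use std::thread;

// Illustrative message type, mirroring the shape of Message<T, V> above:
// two fire-and-forget variants would map to OnEventTrain/OnEventValid,
// the request/response variant maps to FindEpoch, plus an End signal.
enum Request {
    Record(u64),
    Sum(mpsc::SyncSender<u64>),
    End,
}

fn main() {
    let (sender, receiver) = mpsc::channel::<Request>();

    // The worker owns the state, just as WorkerThread owns the collector.
    let worker = thread::spawn(move || {
        let mut total = 0;
        for msg in receiver.iter() {
            match msg {
                Request::Record(value) => total += value,
                Request::Sum(reply) => reply.send(total).unwrap(),
                Request::End => return,
            }
        }
    });

    // Fire-and-forget, like on_event_train / on_event_valid: only a channel send.
    sender.send(Request::Record(3)).unwrap();
    sender.send(Request::Record(4)).unwrap();

    // Request/response, like find_epoch: ship a reply channel and block on it.
    let (reply_tx, reply_rx) = mpsc::sync_channel(1);
    sender.send(Request::Sum(reply_tx)).unwrap();
    assert_eq!(reply_rx.recv().unwrap(), 7);

    // Shutdown, like the Drop impl: send End, then join the worker thread.
    sender.send(Request::End).unwrap();
    worker.join().unwrap();
}
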
78 changes: 78 additions & 0 deletions burn-train/src/collector/base.rs
@@ -0,0 +1,78 @@
use burn_core::{data::dataloader::Progress, LearningRate};

/// Event happening during the training/validation process.
pub enum Event<T> {
    /// Signal that an item has been processed.
    ProcessedItem(LearnerItem<T>),
    /// Signal the end of an epoch.
    EndEpoch(usize),
}

/// Defines how training and validation events are collected.
///
/// This trait also exposes methods that use the collected data to compute useful information.
pub trait EventCollector: Send {
    /// Training item.
    type ItemTrain;
    /// Validation item.
    type ItemValid;

    /// Collect the training event.
    fn on_event_train(&mut self, event: Event<Self::ItemTrain>);

    /// Collect the validation event.
    fn on_event_valid(&mut self, event: Event<Self::ItemValid>);

    /// Find the epoch matching the given criteria from the collected data.
    fn find_epoch(
        &mut self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: Split,
    ) -> Option<usize>;
}

/// How to aggregate the metric.
pub enum Aggregate {
    /// Compute the average.
    Mean,
}

/// The split to use.
pub enum Split {
    /// The training split.
    Train,
    /// The validation split.
    Valid,
}

/// The direction of the query.
pub enum Direction {
    /// Lower is better.
    Lowest,
    /// Higher is better.
    Highest,
}

/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
    /// The item.
    pub item: T,

    /// The progress.
    pub progress: Progress,

    /// The epoch.
    pub epoch: usize,

    /// The total number of epochs.
    pub epoch_total: usize,

    /// The iteration.
    pub iteration: usize,

    /// The learning rate.
    pub lr: Option<LearningRate>,
}
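
For orientation, here is a hedged sketch of what an implementor of EventCollector might look like and how it could be handed to AsyncEventCollector from the previous file. LoggingCollector and wrap are hypothetical names used only for illustration, and the burn_train::collector import path is assumed from the file layout in this diff rather than from confirmed re-exports.

// Import path assumed from the file layout in this diff (burn-train/src/collector/);
// the crate's actual public re-exports may differ.
use burn_train::collector::{Aggregate, AsyncEventCollector, Direction, Event, EventCollector, Split};

// Hypothetical collector that just prints events; it is not part of this commit.
struct LoggingCollector;

impl EventCollector for LoggingCollector {
    type ItemTrain = f64;
    type ItemValid = f64;

    fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
        match event {
            Event::ProcessedItem(item) => {
                println!("train iteration {} (epoch {})", item.iteration, item.epoch)
            }
            Event::EndEpoch(epoch) => println!("end of training epoch {}", epoch),
        }
    }

    fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
        match event {
            Event::ProcessedItem(item) => {
                println!("valid iteration {} (epoch {})", item.iteration, item.epoch)
            }
            Event::EndEpoch(epoch) => println!("end of validation epoch {}", epoch),
        }
    }

    fn find_epoch(
        &mut self,
        _name: &str,
        _aggregate: Aggregate,
        _direction: Direction,
        _split: Split,
    ) -> Option<usize> {
        // This toy collector keeps no history, so it can never select an epoch.
        None
    }
}

// Wrapping the collector moves event handling onto the worker thread; the
// training loop then only pays the cost of sending a channel message.
fn wrap(collector: LoggingCollector) -> AsyncEventCollector<f64, f64> {
    AsyncEventCollector::new(collector)
}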