Skip to content

Commit

Permalink
Add configurable application logger to learner builder (tracel-ai#1774)
Browse files Browse the repository at this point in the history
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct

* Remove unused log module and renames, fmt

* Renamed tracing subscriber logger

* renamed to application logger installer

* book learner configuration update

* fix typo

* unused import
  • Loading branch information
jwric authored May 16, 2024
1 parent 7ab2ba1 commit 8de05e1
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 64 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/learner.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
| Num Epochs | Set the number of epochs. |
| Devices | Set the devices to be used |
| Checkpoint | Restart training from a checkpoint |
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |

When the builder is configured at your liking, you can then move forward to build the learner. The
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
Expand Down
67 changes: 67 additions & 0 deletions crates/burn-train/src/learner/application_logger.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::path::Path;
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{registry, Layer};

/// This trait is used to install an application logger.
///
/// Implementors decide where application log output is routed; for example,
/// [`FileApplicationLoggerInstaller`] sends it to a local file.
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// # Errors
    ///
    /// Returns a human-readable error message if the logger could not be
    /// installed.
    fn install(&self) -> Result<(), String>;
}

/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
    // Path of the log file the installed subscriber writes to.
    path: String,
}

impl FileApplicationLoggerInstaller {
    /// Create a new file application logger that will write to the given file path.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, …),
    /// which is backward compatible with the previous `&str`-only signature.
    pub fn new(path: impl Into<String>) -> Self {
        Self { path: path.into() }
    }
}

impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
fn install(&self) -> Result<(), String> {
let path = Path::new(&self.path);
let writer = tracing_appender::rolling::never(
path.parent().unwrap_or_else(|| Path::new(".")),
path.file_name()
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
);
let layer = tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_writer(writer)
.with_filter(LevelFilter::INFO)
.with_filter(filter_fn(|m| {
if let Some(path) = m.module_path() {
// The wgpu crate is logging too much, so we skip `info` level.
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
return false;
}
}
true
}));

if registry().with(layer).try_init().is_err() {
return Err("Failed to install the file logger.".to_string());
}

let hook = std::panic::take_hook();
let file_path: String = self.path.to_owned();

std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
'{file_path}'\n============="
);
hook(info);
}));

Ok(())
}
}
32 changes: 18 additions & 14 deletions crates/burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::HashSet;
use std::rc::Rc;

use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
Expand All @@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
use crate::{
ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
LearnerSummaryConfig,
};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
Expand Down Expand Up @@ -50,7 +52,7 @@ where
metrics: Metrics<T, V>,
event_store: LogEventStore,
interrupter: TrainingInterrupter,
log_to_file: bool,
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
Expand Down Expand Up @@ -84,7 +86,9 @@ where
event_store: LogEventStore::default(),
renderer: None,
interrupter: TrainingInterrupter::new(),
log_to_file: true,
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
format!("{}/experiment.log", directory).as_str(),
))),
num_loggers: 0,
checkpointer_strategy: Box::new(
ComposedCheckpointingStrategy::builder()
Expand Down Expand Up @@ -233,8 +237,11 @@ where
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn log_to_file(mut self, enabled: bool) -> Self {
self.log_to_file = enabled;
pub fn with_application_logger(
mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self
}

Expand All @@ -258,7 +265,7 @@ where
format!("{}/checkpoint", self.directory).as_str(),
"optim",
);
let checkpointer_scheduler = FileCheckpointer::new(
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
recorder,
format!("{}/checkpoint", self.directory).as_str(),
"scheduler",
Expand Down Expand Up @@ -309,8 +316,10 @@ where
O::Record: 'static,
S::Record: 'static,
{
if self.log_to_file {
self.init_logger();
if self.tracing_logger.is_some() {
if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
log::warn!("Failed to install the experiment logger: {}", e);
}
}
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
Expand Down Expand Up @@ -360,9 +369,4 @@ where
summary,
}
}

fn init_logger(&self) {
let file_path = format!("{}/experiment.log", self.directory);
install_file_logger(file_path.as_str());
}
}
47 changes: 0 additions & 47 deletions crates/burn-train/src/learner/log.rs

This file was deleted.

4 changes: 2 additions & 2 deletions crates/burn-train/src/learner/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod application_logger;
mod base;
mod builder;
mod classification;
Expand All @@ -8,8 +9,7 @@ mod step;
mod summary;
mod train_val;

pub(crate) mod log;

pub use application_logger::*;
pub use base::*;
pub use builder::*;
pub use classification::*;
Expand Down
2 changes: 1 addition & 1 deletion examples/custom-renderer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
.devices(vec![device])
.num_epochs(config.num_epochs)
.renderer(CustomRenderer {})
.log_to_file(false);
.with_application_logger(None);
// can be used to interrupt training
let _interrupter = builder.interrupter();

Expand Down

0 comments on commit 8de05e1

Please sign in to comment.