Skip to content

Commit

Permalink
refactor: multi-thread logger -> async logger
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Aug 27, 2022
1 parent 05eb2f9 commit 9ed0b0b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use burn::tensor::af::relu;
use burn::tensor::back::{ad, Backend};
use burn::tensor::losses::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::{CLILogger, MultiThreadLogger};
use burn::train::logger::{AsyncLogger, CLILogger};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use std::sync::Arc;
Expand Down Expand Up @@ -180,11 +180,11 @@ fn run<B: ad::Backend>(device: B::Device) {

let learner = ClassificationLearner::new(model);

let logger_train = Box::new(MultiThreadLogger::new(Box::new(CLILogger::new(
let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
"Train".to_string(),
))));
let logger_valid = Box::new(MultiThreadLogger::new(Box::new(CLILogger::new(
let logger_valid = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
"Valid".to_string(),
))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ enum Message<T> {
Clear,
}

pub struct MultiThreadLogger<T> {
pub struct AsyncLogger<T> {
sender: mpsc::Sender<Message<T>>,
}

Expand All @@ -33,7 +33,7 @@ impl<T> LoggerThread<T> {
}
}

impl<T: Send + Sync + 'static> MultiThreadLogger<T> {
impl<T: Send + Sync + 'static> AsyncLogger<T> {
pub fn new(logger: Box<dyn Logger<T>>) -> Self {
let (sender, receiver) = mpsc::channel();
let thread = LoggerThread::new(Mutex::new(logger), receiver);
Expand All @@ -44,7 +44,7 @@ impl<T: Send + Sync + 'static> MultiThreadLogger<T> {
}
}

impl<T: Send> Logger<T> for MultiThreadLogger<T> {
impl<T: Send> Logger<T> for AsyncLogger<T> {
fn log(&mut self, item: T) {
self.sender.send(Message::Log(item)).unwrap();
}
Expand Down
4 changes: 2 additions & 2 deletions src/train/logger/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod async_logger;
mod cli;
mod logger;
mod multi_thread;

pub use async_logger::*;
pub use cli::*;
pub use logger::*;
pub use multi_thread::*;

0 comments on commit 9ed0b0b

Please sign in to comment.