Skip to content

Commit

Permalink
Refactor/metric (tracel-ai#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 23, 2022
1 parent d62f2b0 commit 508cfd2
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 63 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ burn-derive = { path = "./burn-derive", version = "0.1.0" }

# Metrics
nvml-wrapper = "0.8"
textplots = "0.8"

# Console
indicatif = "0.17"
Expand Down
21 changes: 12 additions & 9 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use burn::tensor::af::relu;
use burn::tensor::back::{ad, Backend};
use burn::tensor::losses::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::CLILogger;
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, RunningMetric};
use burn::train::logger::{CLILogger, TextPlot};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use std::sync::Arc;

Expand Down Expand Up @@ -159,17 +159,20 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn run<B: ad::Backend>(device: B::Device) {
let batch_size = 64;
let learning_rate = 9.5e-2;
let num_epochs = 20;
let num_workers = 8;
let num_layers = 4;
let hidden_dim = 1024;
let seed = 42;
let metrics = || -> Vec<Box<dyn RunningMetric<ClassificationOutput<B>>>> {
let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
vec![
Box::new(LossMetric::new()),
Box::new(AccuracyMetric::new()),
Box::new(TextPlot::new(LossMetric::new())),
Box::new(TextPlot::new(AccuracyMetric::new())),
Box::new(CUDAMetric::new()),
]
};

let mut model: Model<B> = Model::new(784, 1024, 3, 10);
let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
model.to_device(device);
println!(
"Training '{}' with {} params on backend {} {:?}",
Expand Down Expand Up @@ -206,8 +209,8 @@ fn run<B: ad::Backend>(device: B::Device) {
let logger_test = Box::new(CLILogger::new(metrics(), "Test".to_string()));

let trainer = SupervisedTrainer::new(
dataloader_train,
dataloader_test.clone(),
dataloader_train.clone(),
dataloader_train.clone(),
dataloader_test.clone(),
logger_train,
logger_valid,
Expand All @@ -216,7 +219,7 @@ fn run<B: ad::Backend>(device: B::Device) {
optim,
);

trainer.run(20);
trainer.run(num_epochs);
}

fn main() {
Expand Down
30 changes: 16 additions & 14 deletions src/train/logger/cli.rs → src/train/logger/cli/cli.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use super::{LogItem, Logger};
use crate::train::metric::{LossMetric, RunningMetric, RunningMetricResult};
use crate::train::{
logger::{LogItem, Logger},
metric::{LossMetric, Metric, MetricStateDyn},
};
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use std::fmt::Write;

pub struct CLILogger<T> {
metrics: Vec<Box<dyn RunningMetric<T>>>,
metrics: Vec<Box<dyn Metric<T>>>,
name: String,
pb: ProgressBar,
}

impl<T> Logger<T> for CLILogger<T>
where
LossMetric: RunningMetric<T>,
LossMetric: Metric<T>,
{
fn log(&mut self, item: LogItem<T>) {
let metrics = self.update_metrics(&item);
Expand All @@ -21,7 +23,6 @@ where
let template = self.register_template_progress(&item, template);

let style = ProgressStyle::with_template(&template).unwrap();
let style = self.register_style_metrics(&metrics, style);
let style = self.register_style_progress(&item, style);

if self.pb.length() == Some(0) {
Expand All @@ -31,6 +32,7 @@ where
self.pb.set_style(style.progress_chars("#>-"));
self.pb.set_position(item.iteration as u64);
self.pb.set_length(item.iteration_total as u64);
self.pb.tick();
}

fn clear(&mut self) {
Expand All @@ -44,15 +46,15 @@ where
}

impl<T> CLILogger<T> {
pub fn new(metrics: Vec<Box<dyn RunningMetric<T>>>, name: String) -> Self {
pub fn new(metrics: Vec<Box<dyn Metric<T>>>, name: String) -> Self {
Self {
metrics,
name,
pb: ProgressBar::new(0),
}
}

pub fn update_metrics(&mut self, item: &LogItem<T>) -> Vec<RunningMetricResult> {
pub fn update_metrics(&mut self, item: &LogItem<T>) -> Vec<MetricStateDyn> {
let mut metrics_result = Vec::with_capacity(self.metrics.len());

for metric in &mut self.metrics {
Expand Down Expand Up @@ -82,14 +84,14 @@ impl<T> CLILogger<T> {

pub fn register_template_metrics(
&self,
metrics: &Vec<RunningMetricResult>,
metrics: &Vec<MetricStateDyn>,
template: String,
) -> String {
let mut template = template;
let mut metrics_keys = Vec::new();

for i in 0..metrics.len() {
metrics_keys.push(format!(" - {{metric{}}}", i));
for metric in metrics {
metrics_keys.push(format!(" - {}: {}", metric.name(), metric.pretty()));
}

if metrics.len() > 0 {
Expand Down Expand Up @@ -127,7 +129,7 @@ impl<T> CLILogger<T> {

pub fn register_style_metrics(
&self,
items: &Vec<RunningMetricResult>,
items: &Vec<MetricStateDyn>,
style: ProgressStyle,
) -> ProgressStyle {
let mut style = style;
Expand Down Expand Up @@ -155,10 +157,10 @@ impl<T> CLILogger<T> {
&self,
key: &'static str,
style: ProgressStyle,
metric_result: &RunningMetricResult,
metric_result: &MetricStateDyn,
) -> ProgressStyle {
let formatted = metric_result.formatted.clone();
let name = metric_result.name.clone();
let formatted = metric_result.pretty();
let name = metric_result.name();

self.register_key_item(key, style, name, formatted)
}
Expand Down
5 changes: 5 additions & 0 deletions src/train/logger/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod cli;
mod plot;

pub use cli::*;
pub use plot::*;
90 changes: 90 additions & 0 deletions src/train/logger/cli/plot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use crate::train::metric::{Metric, MetricState, MetricStateDyn, NumericMetric};
use textplots::{Chart, Plot, Shape};

pub struct TextPlot<M: NumericMetric> {
metric: M,
values: Vec<f32>,
}

#[derive(new)]
pub struct TextPlotState {
inner: MetricStateDyn,
plot: String,
}

impl MetricState for TextPlotState {
fn name(&self) -> String {
self.inner.name()
}

fn pretty(&self) -> String {
format!("{}{}", self.inner.pretty(), self.plot)
}

fn serialize(&self) -> String {
self.inner.serialize()
}
}

impl<M: NumericMetric> TextPlot<M> {
pub fn new(metric: M) -> Self {
Self {
metric,
values: Vec::new(),
}
}
}

impl<M, T> Metric<T> for TextPlot<M>
where
M: Metric<T> + NumericMetric,
{
fn update(&mut self, item: &T) -> MetricStateDyn {
let state = self.metric.update(item);
self.values.push(self.metric.value() as f32);

let graph = Chart::new(256, 32, 0.0, self.values.len() as f32)
.lineplot(&Shape::Lines(&smooth_values(&self.values, 256)))
.to_string();

Box::new(TextPlotState::new(state, format!("\n\n{}", graph)))
}

fn clear(&mut self) {
self.metric.clear();
}
}

fn smooth_values(values: &Vec<f32>, size_appox: usize) -> Vec<(f32, f32)> {
let batch_size = values.len() / size_appox;
if batch_size == 0 {
return values
.iter()
.enumerate()
.map(|(i, v)| (i as f32, *v as f32))
.collect();
}

let mut output = Vec::with_capacity(size_appox);
let mut current_sum = 0.0;
let mut current_count = 0;

for value in values.iter() {
current_sum += value;
current_count += 1;

if current_count >= batch_size {
output.push(current_sum / current_count as f32);
}
}

if current_count > 0 {
output.push(current_sum / current_count as f32);
}

output
.iter()
.enumerate()
.map(|(i, v)| (i as f32, *v as f32))
.collect()
}
16 changes: 11 additions & 5 deletions src/train/metric/acc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::RunningMetricResult;
use crate::tensor::back::Backend;
use crate::tensor::Tensor;
use crate::train::metric::RunningMetric;
use crate::train::metric::{Metric, MetricStateDyn, NumericMetric};

pub struct AccuracyMetric {
current: f64,
Expand All @@ -19,8 +19,14 @@ impl AccuracyMetric {
}
}

impl<B: Backend> RunningMetric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> RunningMetricResult {
impl NumericMetric for AccuracyMetric {
fn value(&self) -> f64 {
self.current * 100.0
}
}

impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> MetricStateDyn {
let (outputs, targets) = batch;
let logits_outputs = outputs.argmax(1).to_data();
let logits_targets = targets.argmax(1).to_data();
Expand Down Expand Up @@ -49,12 +55,12 @@ impl<B: Backend> RunningMetric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric
100.0 * self.current
);

RunningMetricResult {
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current,
}
})
}

fn clear(&mut self) {
Expand Down
10 changes: 5 additions & 5 deletions src/train/metric/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::RunningMetricResult;
use crate::train::metric::RunningMetric;
use crate::train::metric::{Metric, MetricState};
use nvml_wrapper::Nvml;

pub struct CUDAMetric {
Expand All @@ -14,8 +14,8 @@ impl CUDAMetric {
}
}

impl<T> RunningMetric<T> for CUDAMetric {
fn update(&mut self, _item: &T) -> RunningMetricResult {
impl<T> Metric<T> for CUDAMetric {
fn update(&mut self, _item: &T) -> Box<dyn MetricState> {
let name = String::from("Cuda");

let mut formatted = String::new();
Expand All @@ -41,12 +41,12 @@ impl<T> RunningMetric<T> for CUDAMetric {
formatted = format!("{} - Usage {}", formatted, utilization_rate_formatted);
}

RunningMetricResult {
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current: String::new(),
}
})
}

fn clear(&mut self) {}
Expand Down
17 changes: 12 additions & 5 deletions src/train/metric/loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::RunningMetricResult;
use crate::tensor::back::Backend;
use crate::tensor::ElementConversion;
use crate::tensor::Tensor;
use crate::train::metric::RunningMetric;
use crate::train::metric::{Metric, MetricState, NumericMetric};

pub struct LossMetric {
current: f64,
Expand All @@ -19,8 +19,15 @@ impl LossMetric {
}
}
}
impl<B: Backend> RunningMetric<Tensor<B, 1>> for LossMetric {
fn update(&mut self, loss: &Tensor<B, 1>) -> RunningMetricResult {

impl NumericMetric for LossMetric {
fn value(&self) -> f64 {
self.current * 100.0
}
}

impl<B: Backend> Metric<Tensor<B, 1>> for LossMetric {
fn update(&mut self, loss: &Tensor<B, 1>) -> Box<dyn MetricState> {
let loss = f64::from_elem(loss.to_data().value[0]);

self.count += 1;
Expand All @@ -33,12 +40,12 @@ impl<B: Backend> RunningMetric<Tensor<B, 1>> for LossMetric {
let raw_current = format!("{}", self.current);
let formatted = format!("running {:.3} current {:.3}", running, self.current);

RunningMetricResult {
Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current,
}
})
}

fn clear(&mut self) {
Expand Down
Loading

0 comments on commit 508cfd2

Please sign in to comment.