Skip to content

Commit

Permalink
Refactor/metric adaptor (tracel-ai#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Dec 26, 2022
1 parent 567adfb commit 248039d
Show file tree
Hide file tree
Showing 14 changed files with 303 additions and 196 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: test

on: [push, pull_request]
on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
test-burn-dataset:
Expand Down
2 changes: 1 addition & 1 deletion burn-tensor/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ Therefore, creating the tape only requires a simple and efficent graph traversal

To run with CUDA set `TORCH_CUDA_VERSION=cu113`.

## Note
## Notes

This crate can be use alone without the entire burn stack and with only selected backends for smaller binaries.
4 changes: 2 additions & 2 deletions burn/src/optim/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub trait Optimizer: Send + Sync {

/// Register the optimizer state for a given parameter.
///
/// # Note
/// # Notes
///
/// This should only be called by generated code.
fn register_param_state<const D: usize>(
Expand All @@ -39,7 +39,7 @@ pub trait Optimizer: Send + Sync {

/// Load the optimizer state for a given parameter.
///
/// # Note
/// # Notes
///
/// This should only be called by generated code.
fn load_param_state<const D: usize>(
Expand Down
22 changes: 17 additions & 5 deletions burn/src/train/learner/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::train::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer
use crate::train::logger::FileMetricLogger;
use crate::train::metric::dashboard::cli::CLIDashboardRenderer;
use crate::train::metric::dashboard::Dashboard;
use crate::train::metric::{Metric, Numeric};
use crate::train::metric::{Adaptor, Metric, Numeric};
use crate::train::AsyncTrainerCallback;
use burn_tensor::backend::ADBackend;
use burn_tensor::Element;
Expand Down Expand Up @@ -53,13 +53,19 @@ where
}

/// Register a training metric.
pub fn metric_train<M: Metric<T> + 'static>(mut self, metric: M) -> Self {
pub fn metric_train<M: Metric + 'static>(mut self, metric: M) -> Self
where
T: Adaptor<M::Input>,
{
self.dashboard.register_train(metric);
self
}

/// Register a validation metric.
pub fn metric_valid<M: Metric<V> + 'static>(mut self, metric: M) -> Self {
pub fn metric_valid<M: Metric + 'static>(mut self, metric: M) -> Self
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid(metric);
self
}
Expand All @@ -86,7 +92,10 @@ where
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
/// the same graph will be used for both.
pub fn metric_train_plot<M: Metric<T> + Numeric + 'static>(mut self, metric: M) -> Self {
pub fn metric_train_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
where
T: Adaptor<M::Input>,
{
self.dashboard.register_train_plot(metric);
self
}
Expand All @@ -98,7 +107,10 @@ where
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
/// the same graph will be used for both.
pub fn metric_valid_plot<M: Metric<V> + Numeric + 'static>(mut self, metric: M) -> Self {
pub fn metric_valid_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid_plot(metric);
self
}
Expand Down
22 changes: 8 additions & 14 deletions burn/src/train/learner/classification.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
use crate::tensor::backend::Backend;
use crate::train::metric;
use crate::train::metric::{AccuracyInput, Adaptor, LossInput};
use burn_tensor::Tensor;

/// Simple classification output adapted for multiple metrics.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 1>,
}

impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
self.update(&item.loss)
}
fn clear(&mut self) {
<metric::LossMetric as metric::Metric<Tensor<B, 1>>>::clear(self);
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> AccuracyInput<B> {
AccuracyInput::new(self.output.clone(), self.targets.clone())
}
}

impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMetric {
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
self.update(&(item.output.clone(), item.targets.clone()))
}

fn clear(&mut self) {
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
17 changes: 9 additions & 8 deletions burn/src/train/logger/metric.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::train::metric::MetricEntry;

use super::{AsyncLogger, FileLogger, Logger};
use crate::train::metric::MetricState;
use std::collections::HashMap;

pub trait MetricLogger: Send {
fn log(&mut self, item: &dyn MetricState);
fn log(&mut self, item: &MetricEntry);
fn epoch(&mut self, epoch: usize);
}

Expand All @@ -24,11 +25,11 @@ impl FileMetricLogger {
}

impl MetricLogger for FileMetricLogger {
fn log(&mut self, item: &dyn MetricState) {
let key = item.name();
let value = item.serialize();
fn log(&mut self, item: &MetricEntry) {
let key = &item.name;
let value = &item.serialize;

let logger = match self.loggers.get_mut(&key) {
let logger = match self.loggers.get_mut(key) {
Some(val) => val,
None => {
let directory = format!("{}/epoch-{}", self.directory, self.epoch);
Expand All @@ -39,11 +40,11 @@ impl MetricLogger for FileMetricLogger {
let logger = AsyncLogger::new(Box::new(logger));

self.loggers.insert(key.clone(), Box::new(logger));
self.loggers.get_mut(&key).unwrap()
self.loggers.get_mut(key).unwrap()
}
};

logger.log(value);
logger.log(value.clone());
}

fn epoch(&mut self, epoch: usize) {
Expand Down
90 changes: 38 additions & 52 deletions burn/src/train/metric/acc.rs
Original file line number Diff line number Diff line change
@@ -1,74 +1,60 @@
use super::RunningMetricResult;
use super::state::{FormatOptions, NumericMetricState};
use super::MetricEntry;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::train::metric::{Metric, MetricStateDyn, Numeric};
use crate::train::metric::{Metric, Numeric};

pub struct AccuracyMetric {
current: f64,
count: usize,
total: usize,
/// The accuracy metric.
#[derive(Default)]
pub struct AccuracyMetric<B: Backend> {
state: NumericMetricState,
_b: B,
}

impl AccuracyMetric {
pub fn new() -> Self {
Self {
count: 0,
current: 0.0,
total: 0,
}
}
/// The [accuracy metric](AccuracyMetric) input type.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
outputs: Tensor<B, 2>,
targets: Tensor<B::IntegerBackend, 1>,
}

impl Default for AccuracyMetric {
fn default() -> Self {
Self::new()
impl<B: Backend> AccuracyMetric<B> {
/// Create the metric.
pub fn new() -> Self {
Self::default()
}
}

impl Numeric for AccuracyMetric {
fn value(&self) -> f64 {
self.current * 100.0
}
}
impl<B: Backend> Metric for AccuracyMetric<B> {
type Input = AccuracyInput<B>;

impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
let (outputs, targets) = batch;
let count_current = outputs.dims()[0];
fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry {
let [batch_size, _n_classes] = input.outputs.dims();

let targets = targets.to_device(B::Device::default());
let outputs = outputs
let targets = input.targets.to_device(B::Device::default());
let outputs = input
.outputs
.argmax(1)
.to_device(B::Device::default())
.reshape([count_current]);
.reshape([batch_size]);

let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;
let accuracy = 100.0 * total_current as f64 / batch_size as f64;

self.count += count_current;
self.total += total_current;
self.current = total_current as f64 / count_current as f64;

let name = String::from("Accurracy");
let running = self.total as f64 / self.count as f64;
let raw_running = format!("{running}");
let raw_current = format!("{}", self.current);
let formatted = format!(
"running {:.2} % current {:.2} %",
100.0 * running,
100.0 * self.current
);

Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current,
})
self.state.update(
accuracy,
batch_size,
FormatOptions::new("Accuracy").unit("%").precision(2),
)
}

fn clear(&mut self) {
self.count = 0;
self.total = 0;
self.current = 0.0;
self.state.reset()
}
}

impl<B: Backend> Numeric for AccuracyMetric<B> {
fn value(&self) -> f64 {
self.state.value()
}
}
54 changes: 29 additions & 25 deletions burn/src/train/metric/base.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,42 @@
pub trait Metric<T>: Send + Sync {
fn update(&mut self, item: &T) -> MetricStateDyn;
/// Metric trait.
///
/// # Notes
///
/// Implementations should define their own input type only used by the metric.
/// This is important since some conflict may happen when the model output is adapted for each
/// metric's input type.
pub trait Metric: Send + Sync {
type Input;

/// Update the metric state and returns the current metric entry.
fn update(&mut self, item: &Self::Input) -> MetricEntry;
/// Clear the metric state.
fn clear(&mut self);
}

pub trait MetricState {
fn name(&self) -> String;
fn pretty(&self) -> String;
fn serialize(&self) -> String;
/// Adaptor are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
/// registed with the [leaner buidler](burn::train::LearnerBuilder).
pub trait Adaptor<T> {
/// Adapt the type to be passed to a [metric](Metric).
fn adapt(&self) -> T;
}

/// Declare a metric to be numeric.
///
/// This is usefull to plot the values of a metric during training.
pub trait Numeric {
fn value(&self) -> f64;
}

pub type MetricStateDyn = Box<dyn MetricState>;

/// Data type that contains the current state of a metric at a given time.
#[derive(new)]
pub struct RunningMetricResult {
pub struct MetricEntry {
/// The name of the metric.
pub name: String,
/// The string to be displayed.
pub formatted: String,
pub raw_running: String,
pub raw_current: String,
}

impl MetricState for RunningMetricResult {
fn name(&self) -> String {
self.name.clone()
}

fn pretty(&self) -> String {
self.formatted.clone()
}

fn serialize(&self) -> String {
self.raw_current.clone()
}
/// The string to be saved.
pub serialize: String,
}
22 changes: 12 additions & 10 deletions burn/src/train/metric/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::RunningMetricResult;
use crate::train::metric::{Metric, MetricState};
use super::Adaptor;
use crate::train::metric::{Metric, MetricEntry};
use nvml_wrapper::Nvml;

/// Track basic cuda infos.
pub struct CUDAMetric {
nvml: Nvml,
}
Expand All @@ -20,8 +21,14 @@ impl Default for CUDAMetric {
}
}

impl<T> Metric<T> for CUDAMetric {
fn update(&mut self, _item: &T) -> Box<dyn MetricState> {
impl<T> Adaptor<()> for T {
fn adapt(&self) {}
}

impl Metric for CUDAMetric {
type Input = ();

fn update(&mut self, _item: &()) -> MetricEntry {
let name = String::from("Cuda");

let mut formatted = String::new();
Expand All @@ -44,12 +51,7 @@ impl<T> Metric<T> for CUDAMetric {
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
}

Box::new(RunningMetricResult {
name,
formatted,
raw_running,
raw_current: String::new(),
})
MetricEntry::new(name, formatted, raw_running)
}

fn clear(&mut self) {}
Expand Down
Loading

0 comments on commit 248039d

Please sign in to comment.