forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor/metric adaptor (tracel-ai#139)
- Loading branch information
1 parent
567adfb
commit 248039d
Showing
14 changed files
with
303 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,23 @@ | ||
use crate::tensor::backend::Backend; | ||
use crate::train::metric; | ||
use crate::train::metric::{AccuracyInput, Adaptor, LossInput}; | ||
use burn_tensor::Tensor; | ||
|
||
/// Simple classification output adapted for multiple metrics. | ||
#[derive(new)] | ||
pub struct ClassificationOutput<B: Backend> { | ||
pub loss: Tensor<B, 1>, | ||
pub output: Tensor<B, 2>, | ||
pub targets: Tensor<B::IntegerBackend, 1>, | ||
} | ||
|
||
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric { | ||
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn { | ||
self.update(&item.loss) | ||
} | ||
fn clear(&mut self) { | ||
<metric::LossMetric as metric::Metric<Tensor<B, 1>>>::clear(self); | ||
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> { | ||
fn adapt(&self) -> AccuracyInput<B> { | ||
AccuracyInput::new(self.output.clone(), self.targets.clone()) | ||
} | ||
} | ||
|
||
impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMetric { | ||
fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn { | ||
self.update(&(item.output.clone(), item.targets.clone())) | ||
} | ||
|
||
fn clear(&mut self) { | ||
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self); | ||
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> { | ||
fn adapt(&self) -> LossInput<B> { | ||
LossInput::new(self.loss.clone()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,60 @@ | ||
use super::RunningMetricResult; | ||
use super::state::{FormatOptions, NumericMetricState}; | ||
use super::MetricEntry; | ||
use crate::tensor::backend::Backend; | ||
use crate::tensor::Tensor; | ||
use crate::train::metric::{Metric, MetricStateDyn, Numeric}; | ||
use crate::train::metric::{Metric, Numeric}; | ||
|
||
pub struct AccuracyMetric { | ||
current: f64, | ||
count: usize, | ||
total: usize, | ||
/// The accuracy metric. | ||
#[derive(Default)] | ||
pub struct AccuracyMetric<B: Backend> { | ||
state: NumericMetricState, | ||
_b: B, | ||
} | ||
|
||
impl AccuracyMetric { | ||
pub fn new() -> Self { | ||
Self { | ||
count: 0, | ||
current: 0.0, | ||
total: 0, | ||
} | ||
} | ||
/// The [accuracy metric](AccuracyMetric) input type. | ||
#[derive(new)] | ||
pub struct AccuracyInput<B: Backend> { | ||
outputs: Tensor<B, 2>, | ||
targets: Tensor<B::IntegerBackend, 1>, | ||
} | ||
|
||
impl Default for AccuracyMetric { | ||
fn default() -> Self { | ||
Self::new() | ||
impl<B: Backend> AccuracyMetric<B> { | ||
/// Create the metric. | ||
pub fn new() -> Self { | ||
Self::default() | ||
} | ||
} | ||
|
||
impl Numeric for AccuracyMetric { | ||
fn value(&self) -> f64 { | ||
self.current * 100.0 | ||
} | ||
} | ||
impl<B: Backend> Metric for AccuracyMetric<B> { | ||
type Input = AccuracyInput<B>; | ||
|
||
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric { | ||
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn { | ||
let (outputs, targets) = batch; | ||
let count_current = outputs.dims()[0]; | ||
fn update(&mut self, input: &AccuracyInput<B>) -> MetricEntry { | ||
let [batch_size, _n_classes] = input.outputs.dims(); | ||
|
||
let targets = targets.to_device(B::Device::default()); | ||
let outputs = outputs | ||
let targets = input.targets.to_device(B::Device::default()); | ||
let outputs = input | ||
.outputs | ||
.argmax(1) | ||
.to_device(B::Device::default()) | ||
.reshape([count_current]); | ||
.reshape([batch_size]); | ||
|
||
let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize; | ||
let accuracy = 100.0 * total_current as f64 / batch_size as f64; | ||
|
||
self.count += count_current; | ||
self.total += total_current; | ||
self.current = total_current as f64 / count_current as f64; | ||
|
||
let name = String::from("Accurracy"); | ||
let running = self.total as f64 / self.count as f64; | ||
let raw_running = format!("{running}"); | ||
let raw_current = format!("{}", self.current); | ||
let formatted = format!( | ||
"running {:.2} % current {:.2} %", | ||
100.0 * running, | ||
100.0 * self.current | ||
); | ||
|
||
Box::new(RunningMetricResult { | ||
name, | ||
formatted, | ||
raw_running, | ||
raw_current, | ||
}) | ||
self.state.update( | ||
accuracy, | ||
batch_size, | ||
FormatOptions::new("Accuracy").unit("%").precision(2), | ||
) | ||
} | ||
|
||
fn clear(&mut self) { | ||
self.count = 0; | ||
self.total = 0; | ||
self.current = 0.0; | ||
self.state.reset() | ||
} | ||
} | ||
|
||
impl<B: Backend> Numeric for AccuracyMetric<B> { | ||
fn value(&self) -> f64 { | ||
self.state.value() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,42 @@ | ||
pub trait Metric<T>: Send + Sync { | ||
fn update(&mut self, item: &T) -> MetricStateDyn; | ||
/// Metric trait. | ||
/// | ||
/// # Notes | ||
/// | ||
/// Implementations should define their own input type only used by the metric. | ||
/// This is important since some conflict may happen when the model output is adapted for each | ||
/// metric's input type. | ||
pub trait Metric: Send + Sync { | ||
type Input; | ||
|
||
/// Update the metric state and returns the current metric entry. | ||
fn update(&mut self, item: &Self::Input) -> MetricEntry; | ||
/// Clear the metric state. | ||
fn clear(&mut self); | ||
} | ||
|
||
pub trait MetricState { | ||
fn name(&self) -> String; | ||
fn pretty(&self) -> String; | ||
fn serialize(&self) -> String; | ||
/// Adaptor are used to transform types so that they can be used by metrics. | ||
/// | ||
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are | ||
/// registed with the [leaner buidler](burn::train::LearnerBuilder). | ||
pub trait Adaptor<T> { | ||
/// Adapt the type to be passed to a [metric](Metric). | ||
fn adapt(&self) -> T; | ||
} | ||
|
||
/// Declare a metric to be numeric. | ||
/// | ||
/// This is usefull to plot the values of a metric during training. | ||
pub trait Numeric { | ||
fn value(&self) -> f64; | ||
} | ||
|
||
pub type MetricStateDyn = Box<dyn MetricState>; | ||
|
||
/// Data type that contains the current state of a metric at a given time. | ||
#[derive(new)] | ||
pub struct RunningMetricResult { | ||
pub struct MetricEntry { | ||
/// The name of the metric. | ||
pub name: String, | ||
/// The string to be displayed. | ||
pub formatted: String, | ||
pub raw_running: String, | ||
pub raw_current: String, | ||
} | ||
|
||
impl MetricState for RunningMetricResult { | ||
fn name(&self) -> String { | ||
self.name.clone() | ||
} | ||
|
||
fn pretty(&self) -> String { | ||
self.formatted.clone() | ||
} | ||
|
||
fn serialize(&self) -> String { | ||
self.raw_current.clone() | ||
} | ||
/// The string to be saved. | ||
pub serialize: String, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.