Add precision classification metric #2293

Merged: 31 commits, Nov 20, 2024
Changes from 1 commit
Commits
c73438f
Implement confusion matrix and precision, first draft
Aug 21, 2024
63f4a1d
Implement confusion matrix
Sep 9, 2024
b9d71b6
format :D
Sep 9, 2024
eac29aa
add agg type to cm, reformat debug representation add testing.
Sep 20, 2024
59db68b
formating and tiny refactor
Sep 21, 2024
4261bd8
add ClassificationMetric trait, rename variables and types, move test…
Sep 21, 2024
5431a2f
change unwrap to expect
Sep 21, 2024
fd2e585
update book
Sep 21, 2024
56965e8
remove unused code
Sep 22, 2024
419438a
changes to make reusing code easier
Sep 22, 2024
dfac847
format :D
Sep 22, 2024
ea4b29c
change to static data tests
Sep 24, 2024
e23aa7b
remove classification metric trait, add auxiliary code for classific…
Oct 14, 2024
60a246b
move classification objects to classification.rs, use rstest, remove …
Oct 21, 2024
c145531
review docstring, add top_k for multiclass tasks.
Oct 23, 2024
0c984c4
move class averaging and metric computation to metric implementation,…
Oct 25, 2024
b0a2939
change struct and var names
Oct 25, 2024
f18e321
Merge branch 'main' into add-to-metrics
Oct 26, 2024
386802c
rename params, enforce nonzero for top_k param, optimize one_hot for …
Oct 30, 2024
b525527
add adaptor por classification input, correct one hot function
Nov 1, 2024
ff7611a
define default for ClassReduction, derive new for Precision metric wi…
Nov 8, 2024
4cbcff2
Merge branch 'main' into add-to-metrics
Nov 8, 2024
eeab0d3
expose PrecisionMetric, change metric initialization
Nov 8, 2024
aea207f
check one_hot input tensor has more than 1 classes and correct it's i…
Nov 16, 2024
410f273
Merge branch 'main' into add-to-metrics
Nov 16, 2024
746fa9d
implement adaptor for MultilabelClassificationOutput and Classificati…
Nov 16, 2024
7428b86
change with_top_k to take usize
Nov 18, 2024
58e1902
Merge branch 'main' into add-to-metrics
Nov 18, 2024
d598f00
Add precision config for binary, multiclass and multilabel
laggui Nov 18, 2024
1542ee9
Fix dummy_classification_input
laggui Nov 18, 2024
03ebe1d
make PrecisionMetric public
Nov 19, 2024
add ClassificationMetric trait, rename variables and types, move test module to lib.rs; make precision a classification metric.
Tiago Sanona committed Sep 21, 2024
commit 4261bd88f0b29cefa65024819529b7586672482e
2 changes: 1 addition & 1 deletion crates/burn-train/Cargo.toml
@@ -37,11 +37,11 @@ crossterm = { workspace = true, optional = true }
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
strum = { workspace = true }
rand.workspace = true


[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0" }
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
approx = "0.5.1"

[package.metadata.docs.rs]
153 changes: 153 additions & 0 deletions crates/burn-train/src/lib.rs
@@ -25,3 +25,156 @@ pub use learner::*;

#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(test)]
pub(crate) mod tests {
use crate::metric::ClassificationInput;
use crate::TestBackend;
use burn_core::prelude::Tensor;
use burn_core::tensor::{Distribution, Shape};
use rand::seq::IteratorRandom;
use std::default::Default;
use strum::EnumIter;

/// Probability of tp before adding errors
pub const TRUE_POSITIVE_RATE: f64 = 0.5;
pub const THRESHOLD: f64 = 0.5;
pub const ERROR_PER_SAMPLE_RATE: f32 = 0.2;

#[derive(EnumIter, Debug)]
pub enum ClassificationType {
Binary,
Multiclass,
Multilabel,
}

fn one_hot_encode(
class_tensor: Tensor<TestBackend, 2>,
n_classes: usize,
) -> Tensor<TestBackend, 2> {
Tensor::stack(
class_tensor
.to_data()
.iter()
.map(|class_index: f32| {
Tensor::<TestBackend, 1>::one_hot(
class_index as usize,
n_classes,
&class_tensor.device(),
)
})
.collect(),
0,
)
}
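For readers skimming the diff: the `one_hot_encode` helper above maps each class index to a one-hot row and stacks the rows. A minimal std-only sketch of the same idea (hypothetical helper using plain `Vec`s instead of burn tensors, not part of the PR):

```rust
/// Std-only sketch of per-sample one-hot encoding: each class index
/// becomes a row of length `n_classes` containing a single 1.0.
fn one_hot_encode(class_indices: &[usize], n_classes: usize) -> Vec<Vec<f32>> {
    class_indices
        .iter()
        .map(|&class_index| {
            let mut row = vec![0.0; n_classes];
            row[class_index] = 1.0; // panics if class_index >= n_classes
            row
        })
        .collect()
}

fn main() {
    let encoded = one_hot_encode(&[2, 0, 1], 3);
    assert_eq!(encoded[0], vec![0.0, 0.0, 1.0]);
    assert_eq!(encoded[1], vec![1.0, 0.0, 0.0]);
    println!("{:?}", encoded);
}
```

The tensor version additionally inherits the device from the input tensor and returns a stacked 2-D tensor rather than nested `Vec`s.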

/// Sample x Class shaped matrix for use in
/// classification metrics testing
pub fn dummy_classification_input(
classification_type: &ClassificationType,
) -> (ClassificationInput<TestBackend>, Tensor<TestBackend, 2>) {
let device = &Default::default();
const N_SAMPLES: usize = 200;
const N_CLASSES: usize = 4;

let error_mask = {
let mut rng = &mut rand::thread_rng();
let change_idx = Tensor::from_floats(
(0..N_SAMPLES)
.into_iter()
.choose_multiple(
&mut rng,
(N_SAMPLES as f32 * ERROR_PER_SAMPLE_RATE) as usize,
)
.as_slice(),
device,
);
let values = change_idx.ones_like();
let mask =
Tensor::zeros(Shape::new([N_SAMPLES]), device).scatter(0, change_idx.int(), values);
mask.unsqueeze_dim(1).bool()
};

let (targets, changed_targets) = match classification_type {
ClassificationType::Binary => {
let targets = Tensor::<TestBackend, 2>::random(
Shape::new([N_SAMPLES, 1]),
Distribution::Bernoulli(TRUE_POSITIVE_RATE),
device,
)
.bool();

let changed_targets = targets.clone().not_equal(error_mask);
(targets, changed_targets.float())
}
ClassificationType::Multiclass => {
let mut classes_changes =
Tensor::<TestBackend, 2>::random([N_SAMPLES, 2], Distribution::Default, device)
.mul_scalar(4)
.int()
.float()
.chunk(2, 1);
let (classes, changes) = (classes_changes.remove(0), classes_changes.remove(0));
let changed_classes = (classes.clone()
+ changes.clamp(1, 2) * error_mask.clone().float())
% N_CLASSES as f32;
(
one_hot_encode(classes, N_CLASSES).bool(),
one_hot_encode(changed_classes, N_CLASSES),
)
}
ClassificationType::Multilabel => {
let targets = Tensor::<TestBackend, 2>::random(
[N_SAMPLES, N_CLASSES - 1],
Distribution::Default,
device,
)
.greater_elem(THRESHOLD);

(targets.clone(), (targets.float() + error_mask.float()) % 2)
}
};
let predictions = changed_targets
.random_like(Distribution::Uniform(0.0, THRESHOLD - 0.1))
.sub(changed_targets.clone())
.abs();

(
ClassificationInput::new(predictions, targets.clone()),
targets.float().sub(changed_targets),
)
}

use burn_core::tensor::cast::ToElement;
use burn_core::tensor::TensorData;
use strum::IntoEnumIterator;

#[test]
fn test_predictions_targets_match() {
for classification_type in ClassificationType::iter() {
let (input, target_diff) = dummy_classification_input(&classification_type);
let thresholded_prediction = input.predictions.clone().greater_elem(THRESHOLD);
TensorData::assert_eq(
&(target_diff.clone() + thresholded_prediction.float()).to_data(),
&input.targets.clone().float().to_data(),
true,
);
}
}

#[test]
fn test_error_rate() {
for classification_type in ClassificationType::iter() {
let (_, target_diff) = dummy_classification_input(&classification_type);
let mean_difference_targets = target_diff
.abs()
.bool()
.any_dim(1)
.float()
.mean()
.into_scalar()
.to_f32();
assert_eq!(mean_difference_targets, ERROR_PER_SAMPLE_RATE,);
}
}
}
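The test module above builds dummy data by flipping the labels of a fixed fraction (`ERROR_PER_SAMPLE_RATE`) of samples, then asserts the observed disagreement rate matches. A hedged std-only sketch of that idea (deterministic index choice instead of `rand`'s `choose_multiple`, `Vec<bool>` instead of burn tensors):

```rust
const N_SAMPLES: usize = 200;
const ERROR_PER_SAMPLE_RATE: f32 = 0.2;

/// Flip the labels at a fixed fraction of sample indices.
/// (Sketch only: flips the first `n_errors` indices; the PR samples
/// the error indices randomly and applies them via a scatter mask.)
fn inject_errors(targets: &[bool], rate: f32) -> Vec<bool> {
    let n_errors = (targets.len() as f32 * rate) as usize;
    targets
        .iter()
        .enumerate()
        .map(|(i, &t)| if i < n_errors { !t } else { t })
        .collect()
}

fn main() {
    let targets = vec![true; N_SAMPLES];
    let changed = inject_errors(&targets, ERROR_PER_SAMPLE_RATE);
    let disagreements = targets
        .iter()
        .zip(changed.iter())
        .filter(|(a, b)| a != b)
        .count();
    // The observed per-sample error rate equals the requested rate,
    // mirroring the `test_error_rate` check above.
    assert_eq!(disagreements as f32 / N_SAMPLES as f32, ERROR_PER_SAMPLE_RATE);
}
```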
114 changes: 69 additions & 45 deletions crates/burn-train/src/metric/base.rs
@@ -1,53 +1,8 @@
use burn_core::prelude::{Backend, Bool, Tensor};
use burn_core::tensor::cast::ToElement;
use burn_core::tensor::Int;
use burn_core::{data::dataloader::Progress, LearningRate};
use strum::EnumIter;

///Aggregation types for Classification metrics
#[derive(EnumIter, Copy, Clone, Debug)]
pub enum AggregationType {
Micro,
Macro,
//Weighted(Box<[f64]>), todo!()
}

impl AggregationType {
pub fn aggregate<B: Backend>(self, cm_mask: Tensor<B, 2, Int>) -> Tensor<B, 1> {
match self {
AggregationType::Macro => cm_mask.sum_dim(0).squeeze(0).float(),
AggregationType::Micro => cm_mask.sum().float(), //MetricAverage::Weighted(weights) => Left(metric.float().sum_dim(0).squeeze(0) * Tensor::from_floats(weights.deref(), &B::Device::default())) todo!()
}
}

pub fn aggregate_mean<B: Backend>(self, cm_mask: Tensor<B, 2, Int>) -> Tensor<B, 1> {
match self {
AggregationType::Macro => cm_mask.float().mean_dim(0).squeeze(0),
AggregationType::Micro => cm_mask.float().mean(), //MetricAverage::Weighted(weights) => Left(metric.float().sum_dim(0).squeeze(0) * Tensor::from_floats(weights.deref(), &B::Device::default())) todo!()
}
}

pub fn to_averaged_tensor<B: Backend>(self, mut metrics: Tensor<B, 1>) -> Tensor<B, 1> {
match self {
AggregationType::Macro => {
if metrics.contains_nan().any().into_scalar() {
let nan_mask = metrics.is_nan();
metrics = metrics
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze(1))
}
metrics.mean()
}
AggregationType::Micro => metrics,
//MetricAverage::Weighted(weights) =>
}
}

pub fn to_averaged_metric<B: Backend>(self, metrics: Tensor<B, 1>) -> f64 {
self.to_averaged_tensor(metrics).into_scalar().to_f64()
}
}

/// Metric metadata that can be used when computing metrics.
pub struct MetricMetadata {
/// The current progress.
@@ -104,6 +59,16 @@ pub trait Metric: Send + Sync {
fn clear(&mut self);
}

/// Classification Metric trait
///
/// Requires an implementation of [Metric](Metric) with `Input = ClassificationInput<B>`.
pub trait ClassificationMetric<B: Backend>: Metric<Input = ClassificationInput<B>> {
/// Sets threshold. Default 0.5
fn with_threshold(self, threshold: f64) -> Self;
/// Sets average type. Default Micro
fn with_average(self, average: ClassificationAverage) -> Self;
}

/// The [classification metric](ClassificationMetric) input type.
#[derive(new, Debug)]
pub struct ClassificationInput<B: Backend> {
@@ -113,6 +78,65 @@ pub struct ClassificationInput<B: Backend> {
pub targets: Tensor<B, 2, Bool>,
}

///Aggregation types for Classification metric average
#[derive(EnumIter, Copy, Clone, Debug)]
pub enum ClassificationAverage {
/// overall aggregation
Micro,
/// over class aggregation
Macro,
// /// over class aggregation, weighted average
//Weighted(Box<[f64]>), todo!()
}

impl ClassificationAverage {
/// sum over samples
pub fn aggregate_sum<B: Backend>(self, sample_class_mask: Tensor<B, 2, Bool>) -> Tensor<B, 1> {
use ClassificationAverage::*;
match self {
Macro => sample_class_mask.float().sum_dim(0).squeeze(0),
Micro => sample_class_mask.float().sum(), //Weighted(weights) => Left(metric.float().sum_dim(0).squeeze(0) * Tensor::from_floats(weights.deref(), &B::Device::default())) todo!()
}
}

/// average over samples
pub fn aggregate_mean<B: Backend>(self, sample_class_mask: Tensor<B, 2, Bool>) -> Tensor<B, 1> {
use ClassificationAverage::*;
match self {
Macro => sample_class_mask.float().mean_dim(0).squeeze(0),
Micro => sample_class_mask.float().mean(), //Weighted(weights) => Left(metric.float().sum_dim(0).squeeze(0) * Tensor::from_floats(weights.deref(), &B::Device::default())) todo!()
}
}

///convert to averaged metric, returns tensor
pub fn to_averaged_tensor<B: Backend>(
self,
mut aggregated_metric: Tensor<B, 1>,
) -> Tensor<B, 1> {
use ClassificationAverage::*;
match self {
Macro => {
if aggregated_metric.contains_nan().any().into_scalar() {
let nan_mask = aggregated_metric.is_nan();
aggregated_metric = aggregated_metric
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze(1))
}
aggregated_metric.mean()
}
Micro => aggregated_metric,
//Weighted(weights) => todo!()
}
}

///convert to averaged metric, returns float
pub fn to_averaged_metric<B: Backend>(self, aggregated_metric: Tensor<B, 1>) -> f64 {
self.to_averaged_tensor(aggregated_metric)
.into_scalar()
.to_f64()
}
}

/// Adaptor are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
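The `ClassificationAverage` variants in this diff implement the usual micro/macro distinction: micro pools counts across all classes before dividing, while macro averages the per-class ratios and (per `to_averaged_tensor`) drops NaN entries from classes that were never predicted. A std-only sketch of that behavior for precision (hypothetical scalar helpers; the PR computes these from tensor confusion-matrix masks):

```rust
// per_class: (true_positives, predicted_positives) per class.

/// Micro precision: pool TP and predicted-positive counts over classes first.
fn micro_precision(per_class: &[(f64, f64)]) -> f64 {
    let tp: f64 = per_class.iter().map(|(tp, _)| tp).sum();
    let pp: f64 = per_class.iter().map(|(_, pp)| pp).sum();
    tp / pp
}

/// Macro precision: average the per-class ratios, skipping NaN classes
/// (0/0 for never-predicted classes), mirroring the NaN filtering in
/// `to_averaged_tensor`.
fn macro_precision(per_class: &[(f64, f64)]) -> f64 {
    let precisions: Vec<f64> = per_class
        .iter()
        .map(|(tp, pp)| tp / pp)
        .filter(|p| !p.is_nan())
        .collect();
    precisions.iter().sum::<f64>() / precisions.len() as f64
}

fn main() {
    // Class 0: 8 TP of 10 predicted; class 1: 1 TP of 5; class 2: never predicted.
    let per_class = [(8.0, 10.0), (1.0, 5.0), (0.0, 0.0)];
    assert_eq!(micro_precision(&per_class), 9.0 / 15.0);
    assert_eq!(macro_precision(&per_class), (0.8 + 0.2) / 2.0);
    println!(
        "micro = {}, macro = {}",
        micro_precision(&per_class),
        macro_precision(&per_class)
    );
}
```

Note how the two averages diverge when class support is imbalanced: micro weights every prediction equally, macro weights every (predicted) class equally.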