-
Notifications
You must be signed in to change notification settings - Fork 482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add precision classification metric #2293
Merged
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
c73438f
Implement confusion matrix and precision, first draft
63f4a1d
Implement confusion matrix
b9d71b6
format :D
eac29aa
add agg type to cm, reformat debug representation add testing.
59db68b
formating and tiny refactor
4261bd8
add ClassificationMetric trait, rename variables and types, move test…
5431a2f
change unwrap to expect
fd2e585
update book
56965e8
remove unused code
419438a
changes to make reusing code easier
dfac847
format :D
ea4b29c
change to static data tests
e23aa7b
remove classification metric trait, add auxiliary code for classific…
60a246b
move classification objects to classification.rs, use rstest, remove …
c145531
review docstring, add top_k for multiclass tasks.
0c984c4
move class averaging and metric computation to metric implementation,…
b0a2939
change struct and var names
f18e321
Merge branch 'main' into add-to-metrics
386802c
rename params, enforce nonzero for top_k param, optimize one_hot for …
b525527
add adaptor por classification input, correct one hot function
ff7611a
define default for ClassReduction, derive new for Precision metric wi…
4cbcff2
Merge branch 'main' into add-to-metrics
eeab0d3
expose PrecisionMetric, change metric initialization
aea207f
check one_hot input tensor has more than 1 classes and correct it's i…
410f273
Merge branch 'main' into add-to-metrics
746fa9d
implement adaptor for MultilabelClassificationOutput and Classificati…
7428b86
change with_top_k to take usize
58e1902
Merge branch 'main' into add-to-metrics
d598f00
Add precision config for binary, multiclass and multilabel
laggui 1542ee9
Fix dummy_classification_input
laggui 03ebe1d
make PrecisionMetric public
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
move classification objects to classification.rs, use rstest, remove …
…approx lib and use tensordata asserts, move aggregate and average functions to ConfusionStats implementation
- Loading branch information
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
use burn_core::prelude::{Backend, Bool, Tensor}; | ||
|
||
/// The [classification metric](ClassificationMetric) input type. | ||
#[derive(new, Debug, Clone)] | ||
pub struct ClassificationInput<B: Backend> { | ||
/// Sample x Class Non thresholded normalized predictions. | ||
pub predictions: Tensor<B, 2>, | ||
/// Sample x Class one-hot encoded target. | ||
pub targets: Tensor<B, 2, Bool>, | ||
} | ||
|
||
impl<B: Backend> From<ClassificationInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) { | ||
fn from(val: ClassificationInput<B>) -> Self { | ||
(val.predictions, val.targets) | ||
} | ||
} | ||
|
||
/// Class Averaging types for Classification metrics. | ||
#[derive(Copy, Clone)] | ||
#[allow(dead_code)] | ||
pub enum ClassAverageType { | ||
laggui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
///Computes the statistics over all classes before averaging | ||
Micro, | ||
///Computes the statistics independently for each class before averaging | ||
Macro, | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With your changes in #2413 we should be able to accept targets that are not one-hot encoded (e.g.,
[1, 0, 2, 2, 0]
instead of[[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]]
) and have it configurable for the metric.But this can be done in a follow-up PR 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's true, I've been looking at the existing code in burn-trai/src/learner/classification and I think it would be easier to just use the adaptor to convert between bin/multiclass/multilable outputs to the general one hot encoded metrics, non?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah my comment was not meant to say that this is where is should be handled specifically 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah of course, I just wanted to get your thoughts on what I had thought to do. Now we have the
ClassificationOutput
andMultiLabelClassificationOutput
adapted forClassificationInput
. It works but I'm still not super happy about it since then the user is able to use, for exampleBinaryPrecisionMetric
with aMultiLabelClassificationOutput
. My idea for the future would be to have separated Inputs and Outputs for each of the classification types such that this would not be possible and would complain at compile time. Thoughts?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand where you're coming from! But I don't see a straightforward way to do this.. we would have to have a different implementation for the binary, multiclass and multilabel precision metrics because the input is an associated type for the
Metric
trait.For now I think it's fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I couldn't think of a way around it but maybe it'll come to me while working on other metrics. 🤞