Add precision classification metric #2293
Merged
Commits (31)
c73438f  Implement confusion matrix and precision, first draft
63f4a1d  Implement confusion matrix
b9d71b6  format :D
eac29aa  add agg type to cm, reformat debug representation add testing.
59db68b  formatting and tiny refactor
4261bd8  add ClassificationMetric trait, rename variables and types, move test…
5431a2f  change unwrap to expect
fd2e585  update book
56965e8  remove unused code
419438a  changes to make reusing code easier
dfac847  format :D
ea4b29c  change to static data tests
e23aa7b  remove classification metric trait, add auxiliary code for classific…
60a246b  move classification objects to classification.rs, use rstest, remove …
c145531  review docstring, add top_k for multiclass tasks.
0c984c4  move class averaging and metric computation to metric implementation,…
b0a2939  change struct and var names
f18e321  Merge branch 'main' into add-to-metrics
386802c  rename params, enforce nonzero for top_k param, optimize one_hot for …
b525527  add adaptor for classification input, correct one hot function
ff7611a  define default for ClassReduction, derive new for Precision metric wi…
4cbcff2  Merge branch 'main' into add-to-metrics
eeab0d3  expose PrecisionMetric, change metric initialization
aea207f  check one_hot input tensor has more than 1 class and correct its i…
410f273  Merge branch 'main' into add-to-metrics
746fa9d  implement adaptor for MultilabelClassificationOutput and Classificati…
7428b86  change with_top_k to take usize
58e1902  Merge branch 'main' into add-to-metrics
d598f00  Add precision config for binary, multiclass and multilabel
1542ee9  Fix dummy_classification_input (laggui)
03ebe1d  make PrecisionMetric public (laggui)
@@ -0,0 +1,9 @@
/// The reduction strategy for classification metrics.
#[derive(Copy, Clone, Default)]
pub enum ClassReduction {
    /// Computes the statistics over all classes before averaging
    Micro,
    /// Computes the statistics independently for each class before averaging
    #[default]
    Macro,
}
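To make the two reduction strategies concrete, here is a minimal standalone sketch. It is plain Rust, not the PR's Burn-based implementation; the precision function and the example counts are hypothetical, chosen only to show how Micro and Macro averaging diverge when classes are imbalanced.

// Hypothetical sketch: precision from per-class true-positive (tp) and
// false-positive (fp) counts under the two ClassReduction strategies.
fn precision(tp: &[f64], fp: &[f64], micro: bool) -> f64 {
    if micro {
        // Micro: aggregate the statistics over all classes, then divide once.
        let (tp_sum, fp_sum): (f64, f64) = (tp.iter().sum(), fp.iter().sum());
        tp_sum / (tp_sum + fp_sum)
    } else {
        // Macro: compute precision independently per class, then average.
        let per_class: Vec<f64> = tp.iter().zip(fp).map(|(&t, &f)| t / (t + f)).collect();
        per_class.iter().sum::<f64>() / per_class.len() as f64
    }
}

fn main() {
    // Two classes: class 0 has tp=90, fp=10; class 1 has tp=1, fp=9.
    let (tp, fp) = ([90.0, 1.0], [10.0, 9.0]);
    println!("micro: {:.3}", precision(&tp, &fp, true));  // 0.827
    println!("macro: {:.3}", precision(&tp, &fp, false)); // 0.500
}

Micro averaging lets the frequent class dominate (91/110 here), while macro averaging weights each class equally, which is why the PR makes Macro the default.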
Shouldn't this also transform the targets with .one_hot(...), but force the number of classes to 2 (assuming binary classification)?
Hmm, I think that would be covered by the case above, since then num_classes == 2. This is, I think, slightly different from binary classification (for example, binary: spam email vs. not spam email, versus multiclass with 2 labels: trees or bushes). Finally, I don't think it would work, as the output and targets should have the same shape.
Spam vs. not spam (i.e., positive vs. negative) is still considered binary classification. I think what you are describing is just that one uses sigmoid to model the positive output (so there is only one score) vs. softmax, where there is a score for each class. But the targets are still represented as 0 or 1.

I'm not sure what self.targets.clone().unsqueeze_dim(1).bool() is supposed to represent 🤔 I could be missing something, though.
Sure, at the end of the day they are all the same thing. The point I was trying to make is that in binary classification tasks I expect the output of the model to be (batch_size, 1), while in multiclass I would expect (batch_size, 2). Thus the targets for the first would be transformed from (batch_size) -> (batch_size, 1), and for the second from (batch_size) -> (batch_size, 2), so they match the shapes of the outputs. Still, this is a preference; if you think it doesn't make sense, we can just assert that the second dim of the output is not less than 2.
Ahh, OK, never mind. I thought this would lead to some issues because the targets are expected to be one-hot encoded, but in reality that is not entirely true for the binary case with a single scalar output and target. The operations performed will still be valid.
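For readers following the shape discussion above, here is a minimal sketch of the two target conventions being compared. It uses plain Rust with a hypothetical one_hot helper over Vec, not Burn's Tensor API, purely to illustrate the shapes.

// Hypothetical sketch: targets of shape (batch_size) one-hot encoded to
// (batch_size, num_classes) so they match a multiclass model output.
fn one_hot(targets: &[usize], num_classes: usize) -> Vec<Vec<u8>> {
    targets
        .iter()
        .map(|&class| {
            let mut row = vec![0; num_classes]; // all zeros ...
            row[class] = 1;                     // ... except the target class
            row
        })
        .collect()
}

fn main() {
    let targets = [0, 1, 1, 0];
    // Multiclass with 2 labels: (4) -> (4, 2), matching a (batch_size, 2) output.
    println!("{:?}", one_hot(&targets, 2)); // [[1, 0], [0, 1], [0, 1], [1, 0]]
    // Binary with a single sigmoid score: targets stay a 0/1 column,
    // (4) -> (4, 1), matching a (batch_size, 1) output.
    let column: Vec<Vec<u8>> = targets.iter().map(|&t| vec![t as u8]).collect();
    println!("{:?}", column); // [[0], [1], [1], [0]]
}

As the thread concludes, both layouts support the same precision computation: for the binary column case, the elementwise comparisons against predictions remain valid even though the targets are not one-hot encoded.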