Add precision classification metric #2293
Conversation
improve dummy classification input. reformat precision and add test with dummy data.
… module to lib.rs make precision a classification metric.
+1 confusion matrix
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #2293      +/-   ##
==========================================
- Coverage   82.93%   82.90%   -0.03%
==========================================
  Files         815      818       +3
  Lines      105344   105603     +259
==========================================
+ Hits        87371    87555     +184
- Misses      17973    18048      +75
```

☔ View full report in Codecov by Sentry.
Ha, I was just discussing this with another user on Discord last week 😄 Thanks for the PR, I'll try to look into it today!
Ok, just took some time to review! Thanks again for contributing 🙏
Overall, the numerical implementation looks correct! 🙂
But I have some comments regarding the implementation.
Also, I would refactor the tests to split them into smaller individual units. I see that you wanted to reuse as much code as possible, which is great, but in this case it also makes the tests complicated.
Each functionality should be isolated in the tests. That makes it easier to 1) understand what is going on (and what is being tested) and 2) pinpoint which test failed in the event of a regression.
With these changes, the `ClassificationType` enum iter can probably be eliminated so that the binary, multiclass and multilabel cases are separated into different test cases.
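To illustrate the suggested split, here is a minimal sketch (the helper and expected values are illustrative, not the crate's actual API): one focused test per classification type instead of iterating over a `ClassificationType` enum.

```rust
#[cfg(test)]
mod tests {
    // Hypothetical helper: precision from raw confusion counts.
    fn precision(true_positives: f64, false_positives: f64) -> f64 {
        true_positives / (true_positives + false_positives)
    }

    #[test]
    fn binary_precision_on_fixed_input() {
        // predictions [1, 0, 1, 1] vs targets [1, 0, 0, 1] -> TP = 2, FP = 1
        assert!((precision(2.0, 1.0) - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn multiclass_precision_on_fixed_input() {
        // Each class is checked in isolation so a failure points at one case.
        assert!((precision(3.0, 0.0) - 1.0).abs() < 1e-9);
    }
}
```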
…tion input, clarify descriptions, remove dead code, rename some objects
Hey, sorry it took me some time. I made some changes and continued the discussion on some of your comments. Lemme know what you think :)
No worries 🙂 Thanks for addressing my comments, I'll take a look at the changes today!
Great stuff! Especially with the tests, much easier to understand what is being tested now 😄
I have some minor comments, but once these are addressed I think we should be good to merge!
…approx lib and use tensordata asserts, move aggregate and average functions to ConfusionStats implementation
Hey, sorry, I've added an extra change that I realized was missing. With just thresholds one couldn't use the metric correctly for multiclass, so I added `top_k` as a mutually exclusive option so it makes sense.
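For context, a rough sketch of the two decision rules being made mutually exclusive here (plain Rust, illustrative only, not the PR's actual code): a threshold binarizes each score independently, while `top_k` marks the k highest-scoring classes as predicted.

```rust
// Threshold rule: every class whose score reaches the threshold is predicted.
fn predictions_from_threshold(scores: &[f64], threshold: f64) -> Vec<bool> {
    scores.iter().map(|&s| s >= threshold).collect()
}

// Top-k rule: only the k highest-scoring classes are predicted.
fn predictions_from_top_k(scores: &[f64], k: usize) -> Vec<bool> {
    let mut indices: Vec<usize> = (0..scores.len()).collect();
    // Sort class indices by descending score and keep the first k.
    indices.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
    let mut predicted = vec![false; scores.len()];
    for &i in indices.iter().take(k) {
        predicted[i] = true;
    }
    predicted
}
```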
No problem, will try to make time to review this today or tomorrow!
Thanks for addressing the changes!
Just a couple of comments regarding some stuff that has moved or been added in the latest commits, but we're almost there 😄
```rust
/// Sample x Class non-thresholded normalized predictions.
pub predictions: Tensor<B, 2>,
/// Sample x Class one-hot encoded target.
pub targets: Tensor<B, 2, Bool>,
```
With your changes in #2413 we should be able to accept targets that are not one-hot encoded (e.g., `[1, 0, 2, 2, 0]` instead of `[[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]]`) and have it configurable for the metric.
But this can be done in a follow-up PR 🙂
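As a rough illustration of that conversion (plain Rust, not the tensor-based implementation from #2413), class-index targets would be expanded to one-hot rows like so:

```rust
// Expand class-index targets like [1, 0, 2, 2, 0] into one-hot rows.
fn one_hot(targets: &[usize], num_classes: usize) -> Vec<Vec<bool>> {
    targets
        .iter()
        .map(|&class| {
            let mut row = vec![false; num_classes];
            row[class] = true;
            row
        })
        .collect()
}

fn main() {
    // [1, 0, 2] -> [[false, true, false], [true, false, false], [false, false, true]]
    println!("{:?}", one_hot(&[1, 0, 2], 3));
}
```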
That's true. I've been looking at the existing code in `burn-train/src/learner/classification`, and I think it would be easier to just use the adaptor to convert binary/multiclass/multilabel outputs to the general one-hot encoded metrics, no?
Yeah, my comment was not meant to say that this is where it should be handled specifically 😄
Yeah of course, I just wanted to get your thoughts on what I had thought to do. Now we have the `ClassificationOutput` and `MultiLabelClassificationOutput` adapted for `ClassificationInput`. It works, but I'm still not super happy about it, since the user is then able to use, for example, `BinaryPrecisionMetric` with a `MultiLabelClassificationOutput`. My idea for the future would be to have separate Inputs and Outputs for each of the classification types such that this would not be possible and would complain at compile time. Thoughts?
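A rough sketch of that compile-time separation idea (all names hypothetical, not the PR's code): each classification type gets its own input struct, so pairing a binary metric with a multilabel output becomes a type error.

```rust
// Distinct input types per classification kind (fields elided for brevity).
struct BinaryInput { /* predictions and targets, shape (batch, 1) */ }
struct MultilabelInput { /* predictions and targets, shape (batch, classes) */ }

struct BinaryPrecisionMetric;

impl BinaryPrecisionMetric {
    // Only BinaryInput is accepted; passing a MultilabelInput fails to compile.
    fn update(&mut self, _input: &BinaryInput) { /* ... */ }
}
```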
I understand where you're coming from! But I don't see a straightforward way to do this: we would have to have a different implementation for the binary, multiclass and multilabel precision metrics, because the input is an associated type of the `Metric` trait.
For now I think it's fine.
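A simplified sketch of the constraint being described (not burn's exact trait definition): because the input is an associated type, a given metric type can be tied to only one input type, so covering all three cases means three separate metric types.

```rust
trait Metric {
    type Input;
    fn update(&mut self, input: &Self::Input);
}

struct BinaryInput;
struct PrecisionMetric;

impl Metric for PrecisionMetric {
    // The input type is fixed here; a second `impl Metric for PrecisionMetric`
    // with a different Input is not allowed.
    type Input = BinaryInput;
    fn update(&mut self, _input: &BinaryInput) { /* ... */ }
}
```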
Yeah, I couldn't think of a way around it but maybe it'll come to me while working on other metrics. 🤞
… make dummy data more predictable and add tests for top_k > 1
…case num_class = 1, reformat dummy data, make use of derive(new) for metric init.
Fyi @antimora, not all changes have been addressed, so further review will be pending until then. The metric should handle the different configurations for binary, multiclass and multilabel (previously suggested separate structs …)
…th class_reduction as default and new setter implementation, move NonZerousize boundary to confusion_stats
Hey hey, sorry it took me some time. I was trying to find a way to tie the metric to the type of Output chosen when implementing the TrainStep for the Model, but I wasn't successful, so I just went with what I pushed now. I also found an error in the `one_hot` function that we pushed before. Lemme know what you think :)
I think the changes are good overall!
Just some minor comments.
Also, what about splitting the metric to have `BinaryPrecision`, `MulticlassPrecision` and `MultilabelPrecision` instead of users having to fiddle with the right `top_k` and `threshold` parameters?
…mplementation, deal with classification output with 1 class, make macro average default, expose ClassReduction type and split precision implementations by classification type
# Conflicts:
#	burn-book/src/building-blocks/metric.md
#	crates/burn-train/Cargo.toml
I've separated the Precision metric by classification type. There is a lot of repeated code, but I couldn't find a better way to do it; got any advice? Lemme know what you think 🚀
I made some changes to reduce the repeated code. The `PrecisionMetric` now has a config and can only be created from one of:

- `PrecisionMetric::binary(threshold)`
- `PrecisionMetric::multiclass(top_k)`
- `PrecisionMetric::multilabel(threshold)`
I think this reduces the friction points we previously discussed.
See my comment also for the adaptor.
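A minimal sketch of that constructor pattern (simplified; field and variant names are illustrative, not the PR's actual code): the config is private, so a `PrecisionMetric` can only be built through one of the three associated functions.

```rust
// Private config enum: users never construct it directly.
enum PrecisionConfig {
    Binary { threshold: f64 },
    Multiclass { top_k: usize },
    Multilabel { threshold: f64 },
}

pub struct PrecisionMetric {
    config: PrecisionConfig,
}

impl PrecisionMetric {
    pub fn binary(threshold: f64) -> Self {
        Self { config: PrecisionConfig::Binary { threshold } }
    }
    pub fn multiclass(top_k: usize) -> Self {
        Self { config: PrecisionConfig::Multiclass { top_k } }
    }
    pub fn multilabel(threshold: f64) -> Self {
        Self { config: PrecisionConfig::Multilabel { threshold } }
    }
}
```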
```rust
    PrecisionInput::new(
        self.output.clone(),
        self.targets.clone().unsqueeze_dim(1).bool(),
    )
}
```
Shouldn't this also transform the targets with `.one_hot(...)` but force the num classes to 2 (assuming binary classification)?
Hum, I think that would be covered by the case above, since then `num_classes == 2`. This, I think, is slightly different from binary classification (I'm thinking, for example, of classifying binary: spam email vs not spam email, and multiclass with 2 labels: trees or bushes). Finally, I don't think it would work, as output and targets should have the same shape.
Spam vs not spam (i.e., positive vs negative) is still considered binary classification. I think what you are talking about is just that one uses sigmoid to model the positive output (so there is only one score) vs softmax where you have a score for each. But the targets are still represented as 0 or 1.
I'm not sure what `self.targets.clone().unsqueeze_dim(1).bool()` is supposed to represent 🤔 I could be missing something though
Sure, at the end of the day they are all the same thing. The point I was trying to make is that in binary classification tasks I expect the output of the model to be (batch_size x 1), while in multiclass I would expect (batch_size x 2); thus the targets for the first would be transformed from (batch_size) -> (batch_size, 1) and for the second from (batch_size) -> (batch_size, 2), so they match the shapes of the outputs. Still, this is a preference; if you think it doesn't make sense, we can just assert that the second dim of the output is not less than 2.
Ahhh ok, nvm. I thought this would lead to some issues because the targets are expected to be one-hot encoded, but in reality this is not entirely true for the binary case with a single scalar output and target. The operations performed will still be valid.
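For the record, the shape alignment discussed above amounts to something like this (plain Rust, illustrative only, not the tensor code): binary targets of shape (batch_size) are expanded to (batch_size, 1) to match a single-score-per-sample output.

```rust
// (batch_size) -> (batch_size, 1): each scalar target becomes a one-element row,
// mirroring what unsqueeze_dim(1) does on the target tensor.
fn unsqueeze_targets(targets: Vec<bool>) -> Vec<[bool; 1]> {
    targets.into_iter().map(|t| [t]).collect()
}
```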
Yeah, thank you for the changes, it looks much better! I like the config solution. Just added `pub` to the `PrecisionMetric`; all else seems fine to me. Also, just out of curiosity, why isn't it possible to use the general `ClassificationInput`?
Good catch on the public 😅

> Also, just out of curiosity, why isn't it possible to use the general `ClassificationInput`?

I simply renamed it since it only applies to the precision metric for now.
Should be good to merge, just one comment left 🙂
Checklist

- [x] Confirmed that the `run-checks all` script has been executed.

Related Issues/PRs
#544
Changes
Hello, so I wanted to add precision to the metrics and ended up realizing that implementing a confusion matrix would help not only with precision but also with other classification metrics. So I implemented it as well, along with some code to handle classification averaging (micro, macro) and thresholds.
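For readers following along, the micro/macro averaging mentioned here works roughly like this (a plain-Rust sketch with illustrative names, not the PR's implementation): micro averaging pools the per-class confusion counts before dividing, while macro averaging averages the per-class precision scores.

```rust
struct ClassCounts {
    true_positives: f64,
    false_positives: f64,
}

// Micro: sum the counts over all classes, then compute one precision.
fn micro_precision(counts: &[ClassCounts]) -> f64 {
    let tp: f64 = counts.iter().map(|c| c.true_positives).sum();
    let fp: f64 = counts.iter().map(|c| c.false_positives).sum();
    tp / (tp + fp)
}

// Macro: compute precision per class, then take the unweighted mean.
fn macro_precision(counts: &[ClassCounts]) -> f64 {
    let total: f64 = counts
        .iter()
        .map(|c| c.true_positives / (c.true_positives + c.false_positives))
        .sum();
    total / counts.len() as f64
}
```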
I also sneaked in some code to generate dummy data for testing the metrics. I'm not sure it is decent enough, but I'm open to suggestions.
I do realize that it is a long PR; if necessary, I can split it into different ones.
I'm new to Rust dev and come from Python, so lemme know if I need to adjust some patterns.
Testing
As said above, I added a dummy classification data generator and tested the confusion matrix and the precision metric's methods with it.
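A tiny sketch of what such a generator can look like (illustrative, not the PR's actual helper): fixed scores and targets chosen so the expected metric values can be derived by hand.

```rust
// Deterministic dummy data: with a 0.5 threshold the predictions are
// [1, 1, 0, 1], giving TP = 2, FP = 1, FN = 1, i.e. precision = 2/3.
fn dummy_binary_data() -> (Vec<f64>, Vec<bool>) {
    let predictions = vec![0.9, 0.8, 0.3, 0.6];
    let targets = vec![true, false, true, true];
    (predictions, targets)
}
```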