
Add precision classification metric #2293

Merged · 31 commits · Nov 20, 2024

Conversation

@tsanona (Contributor) commented Sep 21, 2024

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#544

Changes

Hello, so I wanted to add precision to the metrics and ended up realizing that implementing a confusion matrix would help not only with precision but also with other classification metrics. So I implemented it as well, along with some code to handle classification averaging (micro, macro) and thresholds.
I also sneaked in some code to generate dummy data for testing the metrics. I'm not sure it is decent enough, but I'm open to suggestions.
I do realize that this is a long PR; if necessary I can split it into different ones.
I'm new to Rust dev and come from Python, so lemme know if I need to adjust some patterns.

Testing

As mentioned above, I added a dummy classification data generator and tested the confusion matrix and precision metric's methods with it.
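The relationship being tested above (confusion matrix counts feeding the precision metric) can be sketched in plain Rust; this is a hypothetical standalone illustration, not the burn-train API, and the names confusion_counts and precision are invented here:

```rust
/// Counts (tp, fp, fn, tn) for binary classification given scores in [0, 1]
/// and a decision threshold.
fn confusion_counts(scores: &[f64], targets: &[bool], threshold: f64) -> (usize, usize, usize, usize) {
    let (mut tp, mut fp, mut fn_, mut tn) = (0, 0, 0, 0);
    for (&s, &t) in scores.iter().zip(targets) {
        match (s >= threshold, t) {
            (true, true) => tp += 1,   // predicted positive, actually positive
            (true, false) => fp += 1,  // predicted positive, actually negative
            (false, true) => fn_ += 1, // predicted negative, actually positive
            (false, false) => tn += 1, // predicted negative, actually negative
        }
    }
    (tp, fp, fn_, tn)
}

/// Precision = TP / (TP + FP); 0.0 when there are no positive predictions.
fn precision(scores: &[f64], targets: &[bool], threshold: f64) -> f64 {
    let (tp, fp, _, _) = confusion_counts(scores, targets, threshold);
    if tp + fp == 0 { 0.0 } else { tp as f64 / (tp + fp) as f64 }
}

fn main() {
    let scores = [0.9, 0.8, 0.3, 0.6];
    let targets = [true, false, false, true];
    // Threshold 0.5 predicts positives at indices 0, 1, 3: 2 are correct.
    assert_eq!(confusion_counts(&scores, &targets, 0.5), (2, 1, 0, 1));
    assert!((precision(&scores, &targets, 0.5) - 2.0 / 3.0).abs() < 1e-9);
    println!("precision = {:.3}", precision(&scores, &targets, 0.5));
}
```

A dummy-data generator for tests only needs to produce such score/target pairs with a known expected confusion matrix.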

@antimora (Collaborator) commented Sep 22, 2024

+1 confusion matrix

codecov bot commented Sep 22, 2024

Codecov Report

Attention: Patch coverage is 77.47748% with 50 lines in your changes missing coverage. Please review.

Project coverage is 82.90%. Comparing base (6d105ea) to head (03ebe1d).
Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-train/src/metric/confusion_stats.rs 71.60% 23 Missing ⚠️
crates/burn-train/src/learner/classification.rs 0.00% 15 Missing ⚠️
crates/burn-train/src/metric/precision.rs 90.41% 7 Missing ⚠️
crates/burn-tensor/src/tensor/api/check.rs 0.00% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2293      +/-   ##
==========================================
- Coverage   82.93%   82.90%   -0.03%     
==========================================
  Files         815      818       +3     
  Lines      105344   105603     +259     
==========================================
+ Hits        87371    87555     +184     
- Misses      17973    18048      +75     

@laggui (Member) commented Sep 23, 2024

Ha, I was just discussing this with another user on discord last week 😄

Thanks for the PR, I'll try to look into it today!

@laggui laggui changed the title Add to metrics Add precision classification metric with confusion matrix Sep 23, 2024
@laggui (Member) left a comment

Ok, just took some time to review! Thanks again for contributing 🙏

Overall, the numerical implementation looks correct! 🙂

But I have some comments regarding the implementation.

Also, I would refactor the tests to split them into smaller individual units. I see that you wanted to reuse as much code as possible, which is great, but in this case it also makes the tests complicated.

Each functionality should be isolated in the tests. That makes it easier to 1) understand what is going on (and what is being tested) and 2) pinpoint which test failed in the event of a regression.

With these changes, the ClassificationType enum iter can probably be eliminated so the binary, multiclass and multilabel cases are separated in different test cases.

Review comments (now resolved) on:
  • crates/burn-train/src/lib.rs
  • crates/burn-train/src/metric/base.rs
  • crates/burn-train/src/metric/confusion_matrix.rs
@laggui laggui changed the title Add precision classification metric with confusion matrix Add precision classification metric Sep 24, 2024
Tiago Sanona added 2 commits September 24, 2024 20:15
…tion input, clarify descriptions, remove dead code, rename some objects
@tsanona (Contributor, Author) commented Oct 14, 2024

Hey sorry it took me some time, I made some changes and continued the discussion on some of your comments. Lemme know what you think :)

@tsanona tsanona requested a review from laggui October 14, 2024 14:53
@laggui (Member) commented Oct 16, 2024

> Hey sorry it took me some time, I made some changes and continued the discussion on some of your comments. Lemme know what you think :)

No worries 🙂 Thanks for addressing my comments, I'll take a look at the changes today!

@laggui (Member) left a comment

Great stuff! Especially with the tests, much easier to understand what is being tested now 😄

I have some minor comments, but once these are addressed I think we should be good to merge!

Review comments (now resolved) on:
  • crates/burn-train/src/metric/base.rs
  • crates/burn-train/src/metric/confusion_matrix.rs
  • crates/burn-train/src/metric/confusion_stats.rs
  • crates/burn-train/src/metric/precision.rs
  • crates/burn-train/Cargo.toml
…approx lib and use tensordata asserts, move aggregate and average functions to ConfusionStats implementation
@tsanona tsanona requested a review from laggui October 21, 2024 17:36
@tsanona (Contributor, Author) commented Oct 23, 2024

Hey, sorry, I've added an extra change that I realized was missing. With just thresholds one couldn't use the metric correctly for multiclass, so I added top_k as a mutually exclusive option so it makes sense.
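To illustrate why thresholds alone break down for multiclass and why top_k is the natural alternative: with mutually exclusive classes, a fixed score threshold can select no class at all (or several), while top-k selection always yields exactly k predictions. A hypothetical sketch (top_k and above_threshold are invented names, not burn-train functions):

```rust
/// Indices of the k highest scores (the top-k predicted classes).
fn top_k(scores: &[f64], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    // Sort class indices by descending score (scores assumed non-NaN).
    idx.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
    idx.truncate(k);
    idx
}

/// Indices of classes whose score meets the threshold.
fn above_threshold(scores: &[f64], t: f64) -> Vec<usize> {
    (0..scores.len()).filter(|&i| scores[i] >= t).collect()
}

fn main() {
    // A softmax output where no single class clears a 0.5 threshold.
    let softmax = [0.45, 0.40, 0.15];
    assert!(above_threshold(&softmax, 0.5).is_empty()); // threshold: no prediction
    assert_eq!(top_k(&softmax, 1), vec![0]);            // top-1: always one prediction
}
```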

@laggui (Member) commented Oct 23, 2024

No problem, will try to make time to review this today or tomorrow!

@laggui (Member) left a comment

Thanks for addressing the changes!

Just a couple of comments regarding some stuff that has moved or been added in the latest commits, but we're almost there 😄

Review comment (now resolved) on crates/burn-train/src/metric/confusion_stats.rs
```rust
/// Sample x Class non-thresholded normalized predictions.
pub predictions: Tensor<B, 2>,
/// Sample x Class one-hot encoded target.
pub targets: Tensor<B, 2, Bool>,
```
@laggui (Member):

With your changes in #2413 we should be able to accept targets that are not one-hot encoded (e.g., [1, 0, 2, 2, 0] instead of [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]]) and have it configurable for the metric.

But this can be done in a follow-up PR 🙂
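The conversion described above (class-index targets to the one-hot layout the metric currently expects) could look like this minimal sketch; the one_hot helper here is hypothetical and is not the burn tensor API:

```rust
/// Expands class-index targets, e.g. [1, 0, 2], into one-hot rows,
/// e.g. [[0, 1, 0], [1, 0, 0], [0, 0, 1]] (as booleans).
fn one_hot(targets: &[usize], num_classes: usize) -> Vec<Vec<bool>> {
    targets
        .iter()
        .map(|&c| (0..num_classes).map(|i| i == c).collect())
        .collect()
}

fn main() {
    let encoded = one_hot(&[1, 0, 2, 2, 0], 3);
    assert_eq!(encoded.len(), 5);                     // one row per sample
    assert_eq!(encoded[0], vec![false, true, false]); // class 1 -> [0, 1, 0]
    assert_eq!(encoded[3], vec![false, false, true]); // class 2 -> [0, 0, 1]
}
```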

@tsanona (Contributor, Author):

That's true. I've been looking at the existing code in crates/burn-train/src/learner/classification and I think it would be easier to just use the adaptor to convert binary/multiclass/multilabel outputs to the general one-hot encoded metrics, no?

@laggui (Member):

Yeah, my comment was not meant to say that this is where it should be handled specifically 😄

@tsanona (Contributor, Author):

Yeah, of course, I just wanted to get your thoughts on what I had thought to do. Now we have ClassificationOutput and MultiLabelClassificationOutput adapted for ClassificationInput. It works, but I'm still not super happy about it, since the user is then able to use, for example, BinaryPrecisionMetric with a MultiLabelClassificationOutput. My idea for the future would be to have separate Inputs and Outputs for each of the classification types, so that this would not be possible and would complain at compile time. Thoughts?

@laggui (Member):

I understand where you're coming from! But I don't see a straightforward way to do this: we would have to have a different implementation for the binary, multiclass and multilabel precision metrics because the input is an associated type on the Metric trait.

For now I think it's fine.

@tsanona (Contributor, Author):

Yeah, I couldn't think of a way around it but maybe it'll come to me while working on other metrics. 🤞

Review comments (now resolved) on:
  • crates/burn-train/src/metric/classification.rs
  • crates/burn-train/src/metric/confusion_stats.rs
  • crates/burn-train/src/metric/precision.rs
@antimora antimora requested a review from laggui October 29, 2024 19:18
Tiago Sanona added 2 commits October 30, 2024 02:18
@laggui (Member) commented Nov 4, 2024

FYI @antimora: not all changes have been addressed, so further review will be pending until then.

The metric should handle the different configurations for binary, multiclass and multi-label (previously suggested separate structs BinaryPrecision, MulticlassPrecision and MultilabelPrecision).

Tiago Sanona added 3 commits November 8, 2024 06:09
@tsanona (Contributor, Author) commented Nov 8, 2024

Hey hey, sorry it took me some time. I was trying to find a way to tie the metric to the type of Output chosen when implementing TrainStep for the Model, but I wasn't successful, so I just went with what I pushed now. I also found an error in the one_hot function that we pushed before. Lemme know what you think :)

@laggui (Member) left a comment

I think the changes are good overall!

Just some minor comments.

Also, what about splitting the metric to have BinaryPrecision, MulticlassPrecision and MultilabelPrecision instead of users having to fiddle with the right top_k and threshold parameters?

Review comments (now resolved) on:
  • crates/burn-tensor/src/tensor/api/int.rs
  • crates/burn-train/src/metric/precision.rs
Tiago Sanona added 2 commits November 16, 2024 03:48
…mplementation, deal with classification output with 1 class, make macro average default, expose ClassReduction type and split precision implementations by classification type
# Conflicts:
#	burn-book/src/building-blocks/metric.md
#	crates/burn-train/Cargo.toml
@tsanona (Contributor, Author) commented Nov 16, 2024

I've separated the Precision metric by classification type. There is a lot of repeated code but I couldn't find a better way to do it, got any advice? Lemme know what you think 🚀

@laggui (Member) left a comment

@tsanona

I made some changes to reduce the repeated code. The PrecisionMetric now has a config and can only be created from one of

  • PrecisionMetric::binary(threshold)
  • PrecisionMetric::multiclass(top_k)
  • PrecisionMetric::multilabel(threshold)

I think this reduces the friction points we previously discussed.

See my comment also for the adaptor.
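The constructor pattern described above can be sketched roughly like this; the names mirror the discussion, but the actual burn-train definitions may differ. The point is that threshold and top_k can never be combined incorrectly because each constructor only accepts the parameter that makes sense for its task:

```rust
// Hypothetical sketch of a config-driven PrecisionMetric with
// task-specific constructors (not the actual burn-train code).
enum PrecisionConfig {
    Binary { threshold: f64 },
    Multiclass { top_k: usize },
    Multilabel { threshold: f64 },
}

struct PrecisionMetric {
    config: PrecisionConfig,
}

impl PrecisionMetric {
    fn binary(threshold: f64) -> Self {
        Self { config: PrecisionConfig::Binary { threshold } }
    }
    fn multiclass(top_k: usize) -> Self {
        Self { config: PrecisionConfig::Multiclass { top_k } }
    }
    fn multilabel(threshold: f64) -> Self {
        Self { config: PrecisionConfig::Multilabel { threshold } }
    }
}

fn main() {
    // The only way to obtain a multiclass metric is via multiclass(top_k);
    // there is no constructor that takes both a threshold and a top_k.
    let metric = PrecisionMetric::multiclass(2);
    assert!(matches!(metric.config, PrecisionConfig::Multiclass { top_k: 2 }));
}
```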

Comment on lines +39 to +43:

```rust
    PrecisionInput::new(
        self.output.clone(),
        self.targets.clone().unsqueeze_dim(1).bool(),
    )
}
```
@laggui (Member):

Shouldn't this also transform the targets with .one_hot(...) but force the num classes to 2 (assuming binary classification)?

@tsanona (Contributor, Author) commented Nov 18, 2024:

Hum, I think that would be covered by the case above, since then num_classes == 2. This, I think, is slightly different from binary classification (I'm thinking, for example, of classifying binary: spam email vs. not spam email, and multiclass with 2 labels: trees or bushes). Finally, I don't think it would work, as output and targets should have the same shape.

@laggui (Member):

Spam vs not spam (i.e., positive vs negative) is still considered binary classification. I think what you are talking about is just that one uses sigmoid to model the positive output (so there is only one score) vs softmax where you have a score for each. But the targets are still represented as 0 or 1.

I'm not sure what self.targets.clone().unsqueeze_dim(1).bool() is supposed to represent 🤔 I could be missing something though

@tsanona (Contributor, Author):

Sure, at the end of the day they are all the same thing. The point I was trying to make is that in binary classification tasks I expect the output of the model to be (batch_size x 1), while in multiclass I would expect (batch_size x 2); thus the targets for the first would be transformed from (batch_size) -> (batch_size, 1) and for the second from (batch_size) -> (batch_size, 2), so they match the shapes of the outputs. Still, this is a preference; if you think it doesn't make sense, we can just assert that the second dim of output is not less than 2.

@laggui (Member):

Ahhh ok, nvm. I thought this would lead to some issues because the targets are expected to be one-hot encoded, but in reality that is not entirely true for the binary case with a single scalar output and target. The operations performed will still be valid.
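A minimal sketch of the shape convention settled on here: binary targets of shape (batch_size) are unsqueezed to (batch_size, 1) to match a single-score output, analogous to targets.unsqueeze_dim(1).bool() above. The unsqueeze_dim1 helper is a hypothetical plain-Rust stand-in for the tensor op:

```rust
/// Reshapes integer targets of shape (batch_size) into boolean targets of
/// shape (batch_size, 1), matching a (batch_size x 1) binary model output.
fn unsqueeze_dim1(targets: &[u8]) -> Vec<Vec<bool>> {
    targets.iter().map(|&t| vec![t != 0]).collect()
}

fn main() {
    let targets = [1u8, 0, 1];
    let reshaped = unsqueeze_dim1(&targets);
    assert_eq!(reshaped.len(), 3);       // batch dimension preserved
    assert_eq!(reshaped[0], vec![true]); // each row now has a single column
}
```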

@tsanona (Contributor, Author) commented Nov 19, 2024

Yeah, thank you for the changes, it looks much better! I like the config solution. I just added pub to the PrecisionMetric; all else seems fine to me. Also, just out of curiosity, why isn't it possible to use the general ClassificationInput?

@tsanona tsanona requested a review from laggui November 19, 2024 00:31
@laggui (Member) left a comment

Good catch on the public 😅

> Also, just out of curiosity, why isn't it possible to use the general ClassificationInput?

I simply renamed it since it only applies to the precision metric for now.

Should be good to merge, just one comment left 🙂

@tsanona tsanona requested a review from laggui November 19, 2024 19:14

@laggui laggui merged commit 76e67bf into tracel-ai:main Nov 20, 2024
11 checks passed
@tsanona tsanona deleted the add-to-metrics branch November 20, 2024 17:14