Closes #1121
Our implementation of the Matthews correlation coefficient (MCC) does not work when using DDP. This is because `torchmetrics` automatically concatenates the state variables from different batches when DDP is used, so by the time our MCC `compute` method is called, the state variables are already tensors instead of lists of tensors. `torchmetrics` gets around this with its function `dim_zero_cat`, which checks whether the thing to concatenate is already a tensor; see this example in cosine similarity.

Multiclass MCC has the same problem, with the added difficulty that we drop the batch and task dimensions and then stack along a new dimension when we need to collect batches. I've changed this to keep the dimension until `compute`. Because we dropped task dimensions early, our MulticlassMCC for multitask has been giving incorrect results. I have updated the tests to reflect the actual expected values after comparing to sklearn. (sklearn doesn't support multitask, so I calculated the MCC for each task separately and then averaged.)

For reference, in MulticlassMCC `p` is the number of times each class was predicted, `t` is the number of times each class was the true value, `c` is the number of points we got correct for each task, and `s` is the number of points for each task.
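For anyone checking the updated test values by hand, these four statistics combine into the standard multiclass MCC: (c·s − Σ p_k·t_k) / sqrt((s² − Σ p_k²)·(s² − Σ t_k²)). The snippet below is a minimal pure-Python sketch of that formula, not the code in this PR. The names `class_counts` and `multiclass_mcc` are hypothetical, returning 0.0 on a zero denominator is an assumed convention, and the per-task averaging mirrors the sklearn comparison described above.

```python
import math


def class_counts(preds, targets, n_classes):
    """Compute p, t, c, s for a single task from integer class labels."""
    p = [0] * n_classes  # times each class was predicted
    t = [0] * n_classes  # times each class was the true value
    c = 0                # number of correct predictions
    for pred, target in zip(preds, targets):
        p[pred] += 1
        t[target] += 1
        c += int(pred == target)
    s = len(targets)     # number of points in this task
    return p, t, c, s


def multiclass_mcc(p, t, c, s):
    """Multiclass MCC for a single task from the p, t, c, s statistics."""
    cov_pt = c * s - sum(pk * tk for pk, tk in zip(p, t))
    cov_pp = s * s - sum(pk * pk for pk in p)
    cov_tt = s * s - sum(tk * tk for tk in t)
    denom = math.sqrt(cov_pp * cov_tt)
    # Assumed convention: report 0.0 when the denominator vanishes.
    return cov_pt / denom if denom else 0.0


# Two toy tasks (made-up labels, purely for illustration).
task_a = class_counts(preds=[0, 1, 2, 1], targets=[0, 1, 2, 2], n_classes=3)
task_b = class_counts(preds=[0, 1, 1, 0], targets=[0, 1, 1, 0], n_classes=2)

# Score each task separately, then average across tasks.
per_task = [multiclass_mcc(*task_a), multiclass_mcc(*task_b)]
multitask_mcc = sum(per_task) / len(per_task)
print(per_task, multitask_mcc)
```

Keeping the statistics separate per task until the final step is what makes the average meaningful; pooling `p` and `t` across tasks before applying the formula would generally give a different number.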