Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MCC for DDP and multitask #1131

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

KnathanM
Copy link
Member

Closes #1121

Our implementation of the Matthews correlation coefficient does not work when using DDP. This is because torchmetrics will automatically concatenate the state variables from different batches when DDP is used. So when our MCC compute method is called, the state variables are already tensors instead of lists of tensors. torchmetrics gets around this by using their function dim_zero_cat which checks if the thing to concatenate is already a tensor, see this example in cosine similarity.

The case for Multiclass MCC has the same problem but also the added difficulty that we drop the batch and task dimensions and then stack along a new dimension when we need to collect batches. I've changed this to keep the dimension until compute. Because we dropped task dimensions early, our MulticlassMCC for multitask has been giving incorrect results. I have updated the tests to reflect the actual expected values after comparing to sklearn. (sklearn doesn't support multitask, so I calculated the MCC for each task separately and then averaged.)

For reference I will add that in MulticlassMCC p is the number of times each class was predicted, t is the number of times each class was the true value, c is the number of points we got correct for each task, and s is the number of points for each task.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[v2 BUG]: "TypeError in Chemprop hpopt: torch.cat() got invalid arguments during metric computation"
1 participant