diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py
index a546af60859..4a2df6c683b 100644
--- a/ignite/metrics/confusion_matrix.py
+++ b/ignite/metrics/confusion_matrix.py
@@ -5,7 +5,6 @@
 from ignite.metrics import Metric, MetricsLambda
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import sync_all_reduce, reinit_is_reduced
-from ignite.utils import to_onehot


 class ConfusionMatrix(Metric):
@@ -13,12 +12,12 @@ class ConfusionMatrix(Metric):

     - `update` must receive output of the form `(y_pred, y)`.
     - `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...)
-    - `y` can be of two types:
-        - shape (batch_size, num_categories, ...)
-        - shape (batch_size, ...) and contains ground-truth class indices
+    - `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
+      with or without the background class. During the computation, argmax of `y_pred` is taken to determine
+      predicted classes.

     Args:
-        num_classes (int): number of classes. In case of images, num_classes should also count the background index 0.
+        num_classes (int): number of classes. See notes for more details.
         average (str, optional): confusion matrix values averaging schema: None, "samples", "recall", "precision".
             Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
             samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
@@ -63,9 +62,9 @@ def _check_shape(self, output):
             raise ValueError("y_pred does not have correct number of categories: {} vs {}"
                              .format(y_pred.shape[1], self.num_classes))

-        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
+        if not (y.ndimension() + 1 == y_pred.ndimension()):
             raise ValueError("y_pred must have shape (batch_size, num_categories, ...) and y must have "
-                             "shape of (batch_size, num_categories, ...) or (batch_size, ...), "
+                             "shape of (batch_size, ...), "
                              "but given {} vs {}.".format(y.shape, y_pred.shape))

         y_shape = y.shape
@@ -77,38 +76,38 @@ def _check_shape(self, output):
         if y_shape != y_pred_shape:
             raise ValueError("y and y_pred must have compatible shapes.")

-        return y_pred, y
-
     @reinit_is_reduced
     def update(self, output):
-        y_pred, y = self._check_shape(output)
+        self._check_shape(output)
+        y_pred, y = output

-        if y_pred.shape != y.shape:
-            y_ohe = to_onehot(y.reshape(-1), self.num_classes)
-            y_ohe_t = y_ohe.transpose(0, 1)
-        else:
-            y_ohe_t = y.transpose(0, 1).reshape(y.shape[1], -1)
-            y_ohe_t = y_ohe_t.to(self.confusion_matrix)
+        self._num_examples += y_pred.shape[0]

-        indices = torch.argmax(y_pred, dim=1)
-        y_pred_ohe = to_onehot(indices.reshape(-1), self.num_classes)
-        y_pred_ohe = y_pred_ohe.to(self.confusion_matrix)
+        # target is (batch_size, ...)
+        y_pred = torch.argmax(y_pred, dim=1).flatten()
+        y = y.flatten()

-        self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe)
-        self._num_examples += y_pred.shape[0]
+        target_mask = (y >= 0) & (y < self.num_classes)
+        y = y[target_mask]
+        y_pred = y_pred[target_mask]
+
+        indices = self.num_classes * y + y_pred
+        m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
+        self.confusion_matrix += m.to(self.confusion_matrix)

     @sync_all_reduce('confusion_matrix', '_num_examples')
     def compute(self):
         if self._num_examples == 0:
             raise NotComputableError('Confusion matrix must have at least one example before it can be computed.')
         if self.average:
+            self.confusion_matrix = self.confusion_matrix.float()
             if self.average == "samples":
                 return self.confusion_matrix / self._num_examples
             elif self.average == "recall":
                 return self.confusion_matrix / (self.confusion_matrix.sum(dim=1) + 1e-15)
             elif self.average == "precision":
                 return self.confusion_matrix / (self.confusion_matrix.sum(dim=0) + 1e-15)
-        return self.confusion_matrix.cpu()
+        return self.confusion_matrix


 def IoU(cm, ignore_index=None):
diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py
index 3a20593b76c..bd53af83660 100644
--- a/tests/ignite/metrics/test_confusion_matrix.py
+++ b/tests/ignite/metrics/test_confusion_matrix.py
@@ -7,7 +7,6 @@
 from ignite.exceptions import NotComputableError
 from ignite.metrics import ConfusionMatrix, IoU, mIoU
 from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall
-from ignite.utils import to_onehot

 import pytest

@@ -25,23 +24,19 @@ def test_multiclass_wrong_inputs():
     cm = ConfusionMatrix(10)

     with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\)"):
-        # incompatible shapes
         cm.update((torch.rand(10),
                    torch.randint(0, 2, size=(10,)).long()))

     with pytest.raises(ValueError, match=r"y_pred does not have correct number of categories:"):
-        # incompatible shapes
         cm.update((torch.rand(10, 5, 4),
                    torch.randint(0, 2, size=(10,)).long()))

     with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\) "
                                           r"and y must have "):
-        # incompatible shapes
         cm.update((torch.rand(4, 10, 12, 12),
                    torch.randint(0, 10, size=(10, )).long()))

     with pytest.raises(ValueError, match=r"y and y_pred must have compatible shapes."):
-        # incompatible shapes
         cm.update((torch.rand(4, 10, 12, 14),
                    torch.randint(0, 10, size=(4, 5, 6)).long()))

@@ -50,7 +45,7 @@ def test_multiclass_wrong_inputs():


 def test_multiclass_input_N():
-    # Multiclass input data of shape (N, ) and (N, C)
+    # Multiclass input data of shape (N, )
     def _test_N():
         num_classes = 4
         cm = ConfusionMatrix(num_classes=num_classes)
@@ -98,65 +93,13 @@ def _test_N():
         np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
         assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

-    def _test_NC():
-        num_classes = 4
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(10, num_classes)
-        y_labels = torch.randint(0, num_classes, size=(10,)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        num_classes = 10
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(4, num_classes)
-        y_labels = torch.randint(0, num_classes, size=(4, )).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        # 2-classes
-        num_classes = 2
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(4, num_classes)
-        y_labels = torch.randint(0, num_classes, size=(4,)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        # Batched Updates
-        num_classes = 5
-        cm = ConfusionMatrix(num_classes=num_classes)
-
-        y_pred = torch.rand(100, num_classes)
-        y_labels = torch.randint(0, num_classes, size=(100,)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-
-        batch_size = 16
-        n_iters = y.shape[0] // batch_size + 1
-
-        for i in range(n_iters):
-            idx = i * batch_size
-            cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))
-
-        np_y = y_labels.numpy().ravel()
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
     # check multiple random inputs as random exact occurencies are rare
     for _ in range(10):
         _test_N()
-        _test_NC()


 def test_multiclass_input_NL():
-    # Multiclass input data of shape (N, L) and (N, C, L)
+    # Multiclass input data of shape (N, L)
     def _test_NL():
         num_classes = 4
         cm = ConfusionMatrix(num_classes=num_classes)
@@ -195,55 +138,13 @@ def _test_NL():
         np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
         assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

-    def _test_NCL():
-        num_classes = 4
-        cm = ConfusionMatrix(num_classes=num_classes)
-
-        y_pred = torch.rand(10, num_classes, 5)
-        y_labels = torch.randint(0, num_classes, size=(10, 5)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        num_classes = 10
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(4, num_classes, 5)
-        y_labels = torch.randint(0, num_classes, size=(4, 5)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        # Batched Updates
-        num_classes = 9
-        cm = ConfusionMatrix(num_classes=num_classes)
-
-        y_pred = torch.rand(100, num_classes, 7)
-        y_labels = torch.randint(0, num_classes, size=(100, 7)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-
-        batch_size = 16
-        n_iters = y.shape[0] // batch_size + 1
-
-        for i in range(n_iters):
-            idx = i * batch_size
-            cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))
-
-        np_y = y_labels.numpy().ravel()
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
     # check multiple random inputs as random exact occurencies are rare
     for _ in range(10):
         _test_NL()
-        _test_NCL()


 def test_multiclass_input_NHW():
-    # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
+    # Multiclass input data of shape (N, H, W, ...)
     def _test_NHW():
         num_classes = 5
         cm = ConfusionMatrix(num_classes=num_classes)
@@ -281,50 +182,21 @@ def _test_NHW():
         np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
         assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

-    def _test_NCHW():
-        num_classes = 5
-        cm = ConfusionMatrix(num_classes=num_classes)
-
-        y_pred = torch.rand(4, num_classes, 12, 10)
-        y_labels = torch.randint(0, num_classes, size=(4, 12, 10)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        num_classes = 5
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(4, num_classes, 10, 12, 8)
-        y_labels = torch.randint(0, num_classes, size=(4, 10, 12, 8)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-        cm.update((y_pred, y))
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        np_y = y_labels.numpy().ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
-        # Batched Updates
-        num_classes = 3
-        cm = ConfusionMatrix(num_classes=num_classes)
-        y_pred = torch.rand(100, num_classes, 8, 8)
-        y_labels = torch.randint(0, num_classes, size=(100, 8, 8)).long()
-        y = to_onehot(y_labels, num_classes=num_classes)
-
-        batch_size = 16
-        n_iters = y.shape[0] // batch_size + 1
-
-        for i in range(n_iters):
-            idx = i * batch_size
-            cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))
-
-        np_y = y_labels.numpy().ravel()
-        np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
-        assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())
-
     # check multiple random inputs as random exact occurencies are rare
     for _ in range(10):
         _test_NHW()
-        _test_NCHW()
+
+
+def test_ignored_out_of_num_classes_indices():
+    num_classes = 21
+    cm = ConfusionMatrix(num_classes=num_classes)
+
+    y_pred = torch.rand(4, num_classes, 12, 10)
+    y = torch.randint(0, 255, size=(4, 12, 10)).long()
+    cm.update((y_pred, y))
+    np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
+    np_y = y.numpy().ravel()
+    assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())


 def get_y_true_y_pred():
@@ -636,7 +508,7 @@ def _gather(y):
     output = (th_y_logits, th_y_true)
     cm.update(output)

-    res = cm.compute().numpy() / dist.get_world_size()
+    res = cm.compute().cpu().numpy() / dist.get_world_size()

     assert np.all(true_res == res)

@@ -670,7 +542,7 @@ def _gather(y):
     # Update metric & compute
     output = (th_y_logits, th_y_true)
     cm.update(output)
-    res = cm.compute().numpy()
+    res = cm.compute().cpu().numpy()

     # Compute confusion matrix with sklearn
     th_y_true = _gather(th_y_true)
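For reference, the bincount-based update introduced above can be reproduced outside of ignite. The snippet below is a minimal sketch under the same assumptions (`y_pred` holds logits of shape (batch_size, num_classes, ...), `y` holds integer class indices of shape (batch_size, ...)); the helper name `bincount_confusion_matrix` is illustrative only and is not part of the ignite API.

import torch

def bincount_confusion_matrix(y_pred, y, num_classes):
    # Take argmax over the class dimension to get predicted class indices,
    # then flatten both predictions and targets to 1-D.
    y_pred = torch.argmax(y_pred, dim=1).flatten()
    y = y.flatten()

    # Drop targets outside [0, num_classes), e.g. an ignore label such as 255;
    # this is the behaviour exercised by test_ignored_out_of_num_classes_indices.
    mask = (y >= 0) & (y < num_classes)
    y, y_pred = y[mask], y_pred[mask]

    # Encode each (target, prediction) pair as a single index in [0, num_classes ** 2)
    # and count occurrences: rows are targets, columns are predictions.
    indices = num_classes * y + y_pred
    return torch.bincount(indices, minlength=num_classes ** 2).reshape(num_classes, num_classes)

num_classes = 21
cm = torch.zeros(num_classes, num_classes, dtype=torch.int64)
for _ in range(3):  # accumulate over batches, as ConfusionMatrix.update does
    y_pred = torch.rand(4, num_classes, 12, 10)
    y = torch.randint(0, 255, size=(4, 12, 10)).long()
    cm += bincount_confusion_matrix(y_pred, y, num_classes)

Counting flattened (target, prediction) pairs with a single bincount call avoids materializing the one-hot tensors and matmul of the previous implementation, and the mask makes it cheap to skip out-of-range targets instead of raising.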