diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 3546796a28d..91fa1df4a21 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -4,7 +4,6 @@ from ignite.metrics import Metric, MetricsLambda from ignite.exceptions import NotComputableError -from ignite.utils import to_onehot class ConfusionMatrix(Metric): @@ -12,12 +11,12 @@ class ConfusionMatrix(Metric): - `update` must receive output of the form `(y_pred, y)`. - `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...) - - `y` can be of two types: - - shape (batch_size, num_categories, ...) - - shape (batch_size, ...) and contains ground-truth class indices + - `y` should have the following shape (batch_size, ...) and contains ground-truth class indices + with or without the background class. During the computation, argmax of `y_pred` is taken to determine + predicted classes. Args: - num_classes (int): number of classes. In case of images, num_classes should also count the background index 0. + num_classes (int): number of classes. See notes for more details. average (str, optional): confusion matrix values averaging schema: None, "samples", "recall", "precision". Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values @@ -27,6 +26,12 @@ class ConfusionMatrix(Metric): :class:`~ignite.engine.Engine`'s `process_function`'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. + + Note: + In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only + contribute to the confusion matrix and others are neglected. For example, if `num_classes=20` and target index + equal 255 is encountered, then it is filtered out. + """ def __init__(self, num_classes, average=None, output_transform=lambda x: x): @@ -40,7 +45,8 @@ def __init__(self, num_classes, average=None, output_transform=lambda x: x): super(ConfusionMatrix, self).__init__(output_transform=output_transform) def reset(self): - self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, dtype=torch.float) + self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, + dtype=torch.int64, device='cpu') self._num_examples = 0 def _check_shape(self, output): @@ -54,9 +60,9 @@ def _check_shape(self, output): raise ValueError("y_pred does not have correct number of categories: {} vs {}" .format(y_pred.shape[1], self.num_classes)) - if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()): + if not (y.ndimension() + 1 == y_pred.ndimension()): raise ValueError("y_pred must have shape (batch_size, num_categories, ...) and y must have " - "shape of (batch_size, num_categories, ...) or (batch_size, ...), " + "shape of (batch_size, ...), " "but given {} vs {}.".format(y.shape, y_pred.shape)) y_shape = y.shape @@ -68,38 +74,36 @@ def _check_shape(self, output): if y_shape != y_pred_shape: raise ValueError("y and y_pred must have compatible shapes.") - return y_pred, y - def update(self, output): - y_pred, y = self._check_shape(output) + self._check_shape(output) + y_pred, y = output - if y_pred.shape != y.shape: - y_ohe = to_onehot(y.reshape(-1), self.num_classes) - y_ohe_t = y_ohe.transpose(0, 1).float() - else: - y_ohe_t = y.transpose(0, 1).reshape(y.shape[1], -1).float() + self._num_examples += y_pred.shape[0] - indices = torch.argmax(y_pred, dim=1) - y_pred_ohe = to_onehot(indices.reshape(-1), self.num_classes) - y_pred_ohe = y_pred_ohe.float() + # target is (batch_size, ...) + y_pred = torch.argmax(y_pred, dim=1).flatten() + y = y.flatten() - if self.confusion_matrix.type() != y_ohe_t.type(): - self.confusion_matrix = self.confusion_matrix.type_as(y_ohe_t) + target_mask = (y >= 0) & (y < self.num_classes) + y = y[target_mask] + y_pred = y_pred[target_mask] - self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe).float() - self._num_examples += y_pred.shape[0] + indices = self.num_classes * y + y_pred + m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) + self.confusion_matrix += m.to(self.confusion_matrix) def compute(self): if self._num_examples == 0: raise NotComputableError('Confusion matrix must have at least one example before it can be computed.') if self.average: + self.confusion_matrix = self.confusion_matrix.float() if self.average == "samples": return self.confusion_matrix / self._num_examples elif self.average == "recall": return self.confusion_matrix / (self.confusion_matrix.sum(dim=1) + 1e-15) elif self.average == "precision": return self.confusion_matrix / (self.confusion_matrix.sum(dim=0) + 1e-15) - return self.confusion_matrix.cpu() + return self.confusion_matrix def IoU(cm, ignore_index=None): diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index e0d84b19ad3..f91f4b149bd 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -7,7 +7,6 @@ from ignite.exceptions import NotComputableError from ignite.metrics import ConfusionMatrix, IoU, mIoU from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall -from ignite.utils import to_onehot import pytest @@ -25,37 +24,33 @@ def test_multiclass_wrong_inputs(): cm = ConfusionMatrix(10) with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\)"): - # incompatible shapes cm.update((torch.rand(10), - torch.randint(0, 2, size=(10,)).type(torch.LongTensor))) + torch.randint(0, 2, size=(10,)).long())) with pytest.raises(ValueError, match=r"y_pred does not have correct number of categories:"): - # incompatible shapes cm.update((torch.rand(10, 5, 4), - torch.randint(0, 2, size=(10,)).type(torch.LongTensor))) + torch.randint(0, 2, size=(10,)).long())) with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\) " r"and y must have "): - # incompatible shapes cm.update((torch.rand(4, 10, 12, 12), - torch.randint(0, 10, size=(10, )).type(torch.LongTensor))) + torch.randint(0, 10, size=(10, )).long())) with pytest.raises(ValueError, match=r"y and y_pred must have compatible shapes."): - # incompatible shapes cm.update((torch.rand(4, 10, 12, 14), - torch.randint(0, 10, size=(4, 5, 6)).type(torch.LongTensor))) + torch.randint(0, 10, size=(4, 5, 6)).long())) with pytest.raises(ValueError, match=r"Argument average can None or one of"): ConfusionMatrix(num_classes=10, average="abc") def test_multiclass_input_N(): - # Multiclass input data of shape (N, ) and (N, C) + # Multiclass input data of shape (N, ) def _test_N(): num_classes = 4 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(10, num_classes) - y = torch.randint(0, num_classes, size=(10,)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(10,)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -64,7 +59,7 @@ def _test_N(): num_classes = 10 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(4, num_classes) - y = torch.randint(0, num_classes, size=(4, )).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(4, )).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -74,7 +69,7 @@ def _test_N(): num_classes = 2 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(4, num_classes) - y = torch.randint(0, num_classes, size=(4,)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(4,)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -85,7 +80,7 @@ def _test_N(): cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(100, num_classes) - y = torch.randint(0, num_classes, size=(100,)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(100,)).long() batch_size = 16 n_iters = y.shape[0] // batch_size + 1 @@ -98,71 +93,19 @@ def _test_N(): np_y_pred = y_pred.numpy().argmax(axis=1).ravel() assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - def _test_NC(): - num_classes = 4 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(10, num_classes) - y_labels = torch.randint(0, num_classes, size=(10,)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - num_classes = 10 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(4, num_classes) - y_labels = torch.randint(0, num_classes, size=(4, )).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - # 2-classes - num_classes = 2 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(4, num_classes) - y_labels = torch.randint(0, num_classes, size=(4,)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - # Batched Updates - num_classes = 5 - cm = ConfusionMatrix(num_classes=num_classes) - - y_pred = torch.rand(100, num_classes) - y_labels = torch.randint(0, num_classes, size=(100,)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - - batch_size = 16 - n_iters = y.shape[0] // batch_size + 1 - - for i in range(n_iters): - idx = i * batch_size - cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size])) - - np_y = y_labels.numpy().ravel() - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - # check multiple random inputs as random exact occurencies are rare for _ in range(10): _test_N() - _test_NC() def test_multiclass_input_NL(): - # Multiclass input data of shape (N, L) and (N, C, L) + # Multiclass input data of shape (N, L) def _test_NL(): num_classes = 4 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(10, num_classes, 5) - y = torch.randint(0, num_classes, size=(10, 5)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(10, 5)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -171,7 +114,7 @@ def _test_NL(): num_classes = 10 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(4, num_classes, 5) - y = torch.randint(0, num_classes, size=(4, 5)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(4, 5)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -182,7 +125,7 @@ def _test_NL(): cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(100, num_classes, 7) - y = torch.randint(0, num_classes, size=(100, 7)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(100, 7)).long() batch_size = 16 n_iters = y.shape[0] // batch_size + 1 @@ -195,61 +138,19 @@ def _test_NL(): np_y_pred = y_pred.numpy().argmax(axis=1).ravel() assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - def _test_NCL(): - num_classes = 4 - cm = ConfusionMatrix(num_classes=num_classes) - - y_pred = torch.rand(10, num_classes, 5) - y_labels = torch.randint(0, num_classes, size=(10, 5)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - num_classes = 10 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(4, num_classes, 5) - y_labels = torch.randint(0, num_classes, size=(4, 5)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - # Batched Updates - num_classes = 9 - cm = ConfusionMatrix(num_classes=num_classes) - - y_pred = torch.rand(100, num_classes, 7) - y_labels = torch.randint(0, num_classes, size=(100, 7)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - - batch_size = 16 - n_iters = y.shape[0] // batch_size + 1 - - for i in range(n_iters): - idx = i * batch_size - cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size])) - - np_y = y_labels.numpy().ravel() - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - # check multiple random inputs as random exact occurencies are rare for _ in range(10): _test_NL() - _test_NCL() def test_multiclass_input_NHW(): - # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...) + # Multiclass input data of shape (N, H, W, ...) def _test_NHW(): num_classes = 5 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(4, num_classes, 12, 10) - y = torch.randint(0, num_classes, size=(4, 12, 10)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(4, 12, 10)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -258,7 +159,7 @@ def _test_NHW(): num_classes = 5 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(4, num_classes, 10, 12, 8) - y = torch.randint(0, num_classes, size=(4, 10, 12, 8)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(4, 10, 12, 8)).long() cm.update((y_pred, y)) np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel() @@ -268,7 +169,7 @@ def _test_NHW(): num_classes = 3 cm = ConfusionMatrix(num_classes=num_classes) y_pred = torch.rand(100, num_classes, 8, 8) - y = torch.randint(0, num_classes, size=(100, 8, 8)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(100, 8, 8)).long() batch_size = 16 n_iters = y.shape[0] // batch_size + 1 @@ -281,50 +182,21 @@ def _test_NHW(): np_y_pred = y_pred.numpy().argmax(axis=1).ravel() assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - def _test_NCHW(): - num_classes = 5 - cm = ConfusionMatrix(num_classes=num_classes) - - y_pred = torch.rand(4, num_classes, 12, 10) - y_labels = torch.randint(0, num_classes, size=(4, 12, 10)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - num_classes = 5 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(4, num_classes, 10, 12, 8) - y_labels = torch.randint(0, num_classes, size=(4, 10, 12, 8)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - cm.update((y_pred, y)) - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - np_y = y_labels.numpy().ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - - # Batched Updates - num_classes = 3 - cm = ConfusionMatrix(num_classes=num_classes) - y_pred = torch.rand(100, num_classes, 8, 8) - y_labels = torch.randint(0, num_classes, size=(100, 8, 8)).type(torch.LongTensor) - y = to_onehot(y_labels, num_classes=num_classes) - - batch_size = 16 - n_iters = y.shape[0] // batch_size + 1 - - for i in range(n_iters): - idx = i * batch_size - cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size])) - - np_y = y_labels.numpy().ravel() - np_y_pred = y_pred.numpy().argmax(axis=1).ravel() - assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) - # check multiple random inputs as random exact occurencies are rare for _ in range(10): _test_NHW() - _test_NCHW() + + +def test_ignored_out_of_num_classes_indices(): + num_classes = 21 + cm = ConfusionMatrix(num_classes=num_classes) + + y_pred = torch.rand(4, num_classes, 12, 10) + y = torch.randint(0, 255, size=(4, 12, 10)).long() + cm.update((y_pred, y)) + np_y_pred = y_pred.numpy().argmax(axis=1).ravel() + np_y = y.numpy().ravel() + assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy()) def get_y_true_y_pred(): @@ -581,7 +453,7 @@ def test_cm_recall(): def test_cm_with_average(): num_classes = 5 y_pred = torch.rand(20, num_classes) - y = torch.randint(0, num_classes, size=(20,)).type(torch.LongTensor) + y = torch.randint(0, num_classes, size=(20,)).long() np_y_pred = y_pred.numpy().argmax(axis=1).ravel() np_y = y.numpy().ravel()