Skip to content

Commit

Permalink
Fixes issue pytorch#543 (pytorch#572)
Browse files Browse the repository at this point in the history
* Fixes issue pytorch#543

The previous CM (confusion matrix) implementation had a problem when the target contained non-contiguous indices.
The new implementation is taken almost directly from torchvision's reference code: https://github.com/pytorch/vision/blob/master/references/segmentation/utils.py#L75-L117

This commit also removes support for targets of shape (batchsize, num_categories, ...) where num_categories excludes the background class.
Confusion matrix computation for such targets is possible in almost the same way as for (batchsize, ...) targets, but when a target is all zeros (0, ..., 0), i.e. no classes (background class),
the confusion matrix does not count any true/false predictions.

* Update confusion_matrix.py
  • Loading branch information
vfdev-5 authored and anmolsjoshi committed Aug 4, 2019
1 parent 4d13db2 commit 1856c8e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 182 deletions.
52 changes: 28 additions & 24 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@

from ignite.metrics import Metric, MetricsLambda
from ignite.exceptions import NotComputableError
from ignite.utils import to_onehot


class ConfusionMatrix(Metric):
"""Calculates confusion matrix for multi-class data.
- `update` must receive output of the form `(y_pred, y)`.
- `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...)
- `y` can be of two types:
- shape (batch_size, num_categories, ...)
- shape (batch_size, ...) and contains ground-truth class indices
- `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
with or without the background class. During the computation, argmax of `y_pred` is taken to determine
predicted classes.
Args:
num_classes (int): number of classes. In case of images, num_classes should also count the background index 0.
num_classes (int): number of classes. See notes for more details.
average (str, optional): confusion matrix values averaging schema: None, "samples", "recall", "precision".
Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
Expand All @@ -27,6 +26,12 @@ class ConfusionMatrix(Metric):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
Note:
In case of the targets `y` in `(batch_size, ...)` format, only target indices in the range `[0, num_classes)`
contribute to the confusion matrix; all others are ignored. For example, if `num_classes=20` and a target index
equal to 255 is encountered, it is filtered out.
"""

def __init__(self, num_classes, average=None, output_transform=lambda x: x):
Expand All @@ -40,7 +45,8 @@ def __init__(self, num_classes, average=None, output_transform=lambda x: x):
super(ConfusionMatrix, self).__init__(output_transform=output_transform)

def reset(self):
self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, dtype=torch.float)
self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes,
dtype=torch.int64, device='cpu')
self._num_examples = 0

def _check_shape(self, output):
Expand All @@ -54,9 +60,9 @@ def _check_shape(self, output):
raise ValueError("y_pred does not have correct number of categories: {} vs {}"
.format(y_pred.shape[1], self.num_classes))

if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
if not (y.ndimension() + 1 == y_pred.ndimension()):
raise ValueError("y_pred must have shape (batch_size, num_categories, ...) and y must have "
"shape of (batch_size, num_categories, ...) or (batch_size, ...), "
"shape of (batch_size, ...), "
"but given {} vs {}.".format(y.shape, y_pred.shape))

y_shape = y.shape
Expand All @@ -68,38 +74,36 @@ def _check_shape(self, output):
if y_shape != y_pred_shape:
raise ValueError("y and y_pred must have compatible shapes.")

return y_pred, y

def update(self, output):
y_pred, y = self._check_shape(output)
self._check_shape(output)
y_pred, y = output

if y_pred.shape != y.shape:
y_ohe = to_onehot(y.reshape(-1), self.num_classes)
y_ohe_t = y_ohe.transpose(0, 1).float()
else:
y_ohe_t = y.transpose(0, 1).reshape(y.shape[1], -1).float()
self._num_examples += y_pred.shape[0]

indices = torch.argmax(y_pred, dim=1)
y_pred_ohe = to_onehot(indices.reshape(-1), self.num_classes)
y_pred_ohe = y_pred_ohe.float()
# target is (batch_size, ...)
y_pred = torch.argmax(y_pred, dim=1).flatten()
y = y.flatten()

if self.confusion_matrix.type() != y_ohe_t.type():
self.confusion_matrix = self.confusion_matrix.type_as(y_ohe_t)
target_mask = (y >= 0) & (y < self.num_classes)
y = y[target_mask]
y_pred = y_pred[target_mask]

self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe).float()
self._num_examples += y_pred.shape[0]
indices = self.num_classes * y + y_pred
m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
self.confusion_matrix += m.to(self.confusion_matrix)

def compute(self):
if self._num_examples == 0:
raise NotComputableError('Confusion matrix must have at least one example before it can be computed.')
if self.average:
self.confusion_matrix = self.confusion_matrix.float()
if self.average == "samples":
return self.confusion_matrix / self._num_examples
elif self.average == "recall":
return self.confusion_matrix / (self.confusion_matrix.sum(dim=1) + 1e-15)
elif self.average == "precision":
return self.confusion_matrix / (self.confusion_matrix.sum(dim=0) + 1e-15)
return self.confusion_matrix.cpu()
return self.confusion_matrix


def IoU(cm, ignore_index=None):
Expand Down
Loading

0 comments on commit 1856c8e

Please sign in to comment.