
Commit

Merge remote-tracking branch 'origin' into distrib
vfdev-5 committed Aug 7, 2019
2 parents 0bcc287 + 1856c8e commit 1bda698
Showing 2 changed files with 38 additions and 167 deletions.
43 changes: 21 additions & 22 deletions ignite/metrics/confusion_matrix.py
@@ -5,20 +5,19 @@
from ignite.metrics import Metric, MetricsLambda
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import sync_all_reduce, reinit_is_reduced
from ignite.utils import to_onehot


class ConfusionMatrix(Metric):
"""Calculates confusion matrix for multi-class data.
- `update` must receive output of the form `(y_pred, y)`.
- `y_pred` must contain logits and has the following shape (batch_size, num_categories, ...)
- `y` can be of two types:
- shape (batch_size, num_categories, ...)
- shape (batch_size, ...) and contains ground-truth class indices
- `y` should have the following shape (batch_size, ...) and contains ground-truth class indices
with or without the background class. During the computation, argmax of `y_pred` is taken to determine
predicted classes.
Args:
num_classes (int): number of classes. In case of images, num_classes should also count the background index 0.
num_classes (int): number of classes. See notes for more details.
average (str, optional): confusion matrix values averaging schema: None, "samples", "recall", "precision".
Default is None. If `average="samples"` then confusion matrix values are normalized by the number of seen
samples. If `average="recall"` then confusion matrix values are normalized such that diagonal values
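For reference, a minimal usage sketch of the interface described above; the shapes and values are illustrative, not taken from the commit:

import torch
from ignite.metrics import ConfusionMatrix

num_classes = 3
cm = ConfusionMatrix(num_classes=num_classes)

# y_pred: raw logits with shape (batch_size, num_categories, ...)
y_pred = torch.rand(4, num_classes, 8, 8)
# y: ground-truth class indices with shape (batch_size, ...), no one-hot encoding
y = torch.randint(0, num_classes, size=(4, 8, 8)).long()

cm.update((y_pred, y))   # argmax over the class dimension is taken internally
print(cm.compute())      # (num_classes, num_classes) tensor of counts when average is None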
@@ -63,9 +62,9 @@ def _check_shape(self, output):
raise ValueError("y_pred does not have correct number of categories: {} vs {}"
.format(y_pred.shape[1], self.num_classes))

if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
if not (y.ndimension() + 1 == y_pred.ndimension()):
raise ValueError("y_pred must have shape (batch_size, num_categories, ...) and y must have "
"shape of (batch_size, num_categories, ...) or (batch_size, ...), "
"shape of (batch_size, ...), "
"but given {} vs {}.".format(y.shape, y_pred.shape))

y_shape = y.shape
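To make the tightened check concrete, a small sketch with assumed shapes: y_pred must now have exactly one more dimension than y (the class dimension), so a one-hot encoded y is rejected:

y_pred = torch.rand(4, 10, 12, 12)                          # (batch, num_classes, H, W)
y_indices = torch.randint(0, 10, size=(4, 12, 12)).long()   # (batch, H, W): accepted
y_onehot = torch.rand(4, 10, 12, 12)                         # same ndim as y_pred: now rejected

cm = ConfusionMatrix(num_classes=10)
cm.update((y_pred, y_indices))    # passes the new shape check
# cm.update((y_pred, y_onehot))   # ValueError: y must have shape (batch_size, ...)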
@@ -77,38 +76,38 @@ def _check_shape(self, output):
if y_shape != y_pred_shape:
raise ValueError("y and y_pred must have compatible shapes.")

return y_pred, y

@reinit_is_reduced
def update(self, output):
y_pred, y = self._check_shape(output)
self._check_shape(output)
y_pred, y = output

if y_pred.shape != y.shape:
y_ohe = to_onehot(y.reshape(-1), self.num_classes)
y_ohe_t = y_ohe.transpose(0, 1)
else:
y_ohe_t = y.transpose(0, 1).reshape(y.shape[1], -1)
y_ohe_t = y_ohe_t.to(self.confusion_matrix)
self._num_examples += y_pred.shape[0]

indices = torch.argmax(y_pred, dim=1)
y_pred_ohe = to_onehot(indices.reshape(-1), self.num_classes)
y_pred_ohe = y_pred_ohe.to(self.confusion_matrix)
# target is (batch_size, ...)
y_pred = torch.argmax(y_pred, dim=1).flatten()
y = y.flatten()

self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe)
self._num_examples += y_pred.shape[0]
target_mask = (y >= 0) & (y < self.num_classes)
y = y[target_mask]
y_pred = y_pred[target_mask]

indices = self.num_classes * y + y_pred
m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
self.confusion_matrix += m.to(self.confusion_matrix)

@sync_all_reduce('confusion_matrix', '_num_examples')
def compute(self):
if self._num_examples == 0:
raise NotComputableError('Confusion matrix must have at least one example before it can be computed.')
if self.average:
self.confusion_matrix = self.confusion_matrix.float()
if self.average == "samples":
return self.confusion_matrix / self._num_examples
elif self.average == "recall":
return self.confusion_matrix / (self.confusion_matrix.sum(dim=1) + 1e-15)
elif self.average == "precision":
return self.confusion_matrix / (self.confusion_matrix.sum(dim=0) + 1e-15)
return self.confusion_matrix.cpu()
return self.confusion_matrix


def IoU(cm, ignore_index=None):
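The hunk above replaces the one-hot/matmul accumulation with a bincount over flattened (target, prediction) pairs, dropping targets outside [0, num_classes). A standalone sketch of that idea, with illustrative names rather than the library's code:

import torch

def accumulate_confusion(cm, y_pred_logits, y, num_classes):
    # Predicted class per element, then flatten predictions and targets.
    y_pred = torch.argmax(y_pred_logits, dim=1).flatten()
    y = y.flatten()

    # Drop targets outside [0, num_classes), e.g. an "ignore" label such as 255.
    valid = (y >= 0) & (y < num_classes)
    y, y_pred = y[valid], y_pred[valid]

    # Encode each (target, prediction) pair as one index and count occurrences:
    # row = true class, column = predicted class.
    pair_index = num_classes * y + y_pred
    counts = torch.bincount(pair_index, minlength=num_classes ** 2)
    return cm + counts.reshape(num_classes, num_classes).to(cm)

compute() then optionally normalizes these counts: average="samples" divides by the number of seen examples, while average="recall" and average="precision" divide by row and column sums respectively (with a 1e-15 epsilon to avoid division by zero).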
162 changes: 17 additions & 145 deletions tests/ignite/metrics/test_confusion_matrix.py
@@ -7,7 +7,6 @@
from ignite.exceptions import NotComputableError
from ignite.metrics import ConfusionMatrix, IoU, mIoU
from ignite.metrics.confusion_matrix import cmAccuracy, cmPrecision, cmRecall
from ignite.utils import to_onehot

import pytest

@@ -25,23 +24,19 @@ def test_multiclass_wrong_inputs():
cm = ConfusionMatrix(10)

with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\)"):
# incompatible shapes
cm.update((torch.rand(10),
torch.randint(0, 2, size=(10,)).long()))

with pytest.raises(ValueError, match=r"y_pred does not have correct number of categories:"):
# incompatible shapes
cm.update((torch.rand(10, 5, 4),
torch.randint(0, 2, size=(10,)).long()))

with pytest.raises(ValueError, match=r"y_pred must have shape \(batch_size, num_categories, ...\) "
r"and y must have "):
# incompatible shapes
cm.update((torch.rand(4, 10, 12, 12),
torch.randint(0, 10, size=(10, )).long()))

with pytest.raises(ValueError, match=r"y and y_pred must have compatible shapes."):
# incompatible shapes
cm.update((torch.rand(4, 10, 12, 14),
torch.randint(0, 10, size=(4, 5, 6)).long()))

@@ -50,7 +45,7 @@ def test_multiclass_wrong_inputs():


def test_multiclass_input_N():
# Multiclass input data of shape (N, ) and (N, C)
# Multiclass input data of shape (N, )
def _test_N():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)
@@ -98,65 +93,13 @@ def _test_N():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NC():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(10, num_classes)
y_labels = torch.randint(0, num_classes, size=(10,)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 10
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes)
y_labels = torch.randint(0, num_classes, size=(4, )).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# 2-classes
num_classes = 2
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes)
y_labels = torch.randint(0, num_classes, size=(4,)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(100, num_classes)
y_labels = torch.randint(0, num_classes, size=(100,)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurencies are rare
for _ in range(10):
_test_N()
_test_NC()


def test_multiclass_input_NL():
# Multiclass input data of shape (N, L) and (N, C, L)
# Multiclass input data of shape (N, L)
def _test_NL():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)
@@ -195,55 +138,13 @@ def _test_NL():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NCL():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(10, num_classes, 5)
y_labels = torch.randint(0, num_classes, size=(10, 5)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 10
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes, 5)
y_labels = torch.randint(0, num_classes, size=(4, 5)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 9
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(100, num_classes, 7)
y_labels = torch.randint(0, num_classes, size=(100, 7)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurencies are rare
for _ in range(10):
_test_NL()
_test_NCL()


def test_multiclass_input_NHW():
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
# Multiclass input data of shape (N, H, W, ...)
def _test_NHW():
num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)
@@ -281,50 +182,21 @@ def _test_NHW():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NCHW():
num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(4, num_classes, 12, 10)
y_labels = torch.randint(0, num_classes, size=(4, 12, 10)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes, 10, 12, 8)
y_labels = torch.randint(0, num_classes, size=(4, 10, 12, 8)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 3
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(100, num_classes, 8, 8)
y_labels = torch.randint(0, num_classes, size=(100, 8, 8)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurencies are rare
for _ in range(10):
_test_NHW()
_test_NCHW()


def test_ignored_out_of_num_classes_indices():
num_classes = 21
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(4, num_classes, 12, 10)
y = torch.randint(0, 255, size=(4, 12, 10)).long()
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())


def get_y_true_y_pred():
@@ -636,7 +508,7 @@ def _gather(y):
output = (th_y_logits, th_y_true)
cm.update(output)

res = cm.compute().numpy() / dist.get_world_size()
res = cm.compute().cpu().numpy() / dist.get_world_size()

assert np.all(true_res == res)

@@ -670,7 +542,7 @@ def _gather(y):
# Update metric & compute
output = (th_y_logits, th_y_true)
cm.update(output)
res = cm.compute().numpy()
res = cm.compute().cpu().numpy()

# Compute confusion matrix with sklearn
th_y_true = _gather(th_y_true)
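The test changes above suggest compute() now leaves the confusion matrix on the metric's device rather than moving it to CPU, so callers converting to NumPy need an explicit .cpu(). A hedged sketch, assuming the matrix lives on a GPU:

res = cm.compute()         # confusion matrix tensor, still on the metric's device after sync_all_reduce
res = res.cpu().numpy()    # explicit move to CPU before NumPy conversion

In the first distributed test the result is additionally divided by dist.get_world_size(), since every process feeds the same data and the all-reduced matrix therefore counts each sample once per process.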
