Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distrib #635

Merged
merged 29 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2afc205
[WIP] Added cifar10 distributed example
vfdev-5 Aug 1, 2019
7b8eac9
[WIP] Metric with all reduce decorator and tests
vfdev-5 Aug 1, 2019
c7d2337
[WIP] Added tests for accumulation metric
vfdev-5 Aug 1, 2019
69ced1e
[WIP] Updated with reinit_is_reduced
vfdev-5 Aug 1, 2019
f2f923b
[WIP] Distrib adaptation for other metrics
vfdev-5 Aug 2, 2019
d13b985
[WIP] Warnings for EpochMetric and Precision/Recall when distrib
vfdev-5 Aug 2, 2019
e7d12d0
Updated metrics and tests to run on distributed configuration
vfdev-5 Aug 3, 2019
0a5f582
Minor fixes and cosmetics
vfdev-5 Aug 3, 2019
954269c
Merge branch 'master' into distrib
vfdev-5 Aug 3, 2019
206f2e1
Fixed bugs and improved contrib/cifar10 example
vfdev-5 Aug 3, 2019
99a6b4a
Updated docs
vfdev-5 Aug 3, 2019
3eff370
Update metrics.rst
vfdev-5 Aug 6, 2019
ad8375c
Updated docs and set device as "cuda" in distributed instead of raisi…
vfdev-5 Aug 6, 2019
0bcc287
[WIP] Fix missing _is_reduced in precision/recall with tests
vfdev-5 Aug 7, 2019
1bda698
Merge remote-tracking branch 'origin' into distrib
vfdev-5 Aug 7, 2019
7dd6937
Updated other tests
vfdev-5 Aug 7, 2019
27324dc
Merge branch 'master' into distrib
vfdev-5 Aug 29, 2019
f4a3d4b
Updated travis and renamed tbptt test gpu -> cuda
vfdev-5 Aug 29, 2019
2036075
Distrib (#573)
vfdev-5 Aug 30, 2019
69502fc
Merge branch 'distrib' of https://github.com/pytorch/ignite into distrib
vfdev-5 Sep 9, 2019
d52c36d
Merge branch 'master' into distrib
vfdev-5 Sep 9, 2019
ecb00a5
Merge branch 'master' into distrib
vfdev-5 Sep 13, 2019
71836aa
Merge branch 'master' into distrib
vfdev-5 Sep 25, 2019
46cdd86
Compute IoU, Precision, Recall based on CM on CPU
vfdev-5 Sep 26, 2019
fd14d4d
Fixes incomplete merge with 1856c8e0f1be102d4530592bcb7caac690f198c4
vfdev-5 Sep 26, 2019
59b894c
Merge branch 'master' into distrib
vfdev-5 Oct 17, 2019
80ad40a
Update distrib branch and CIFAR10 example (#647)
vfdev-5 Oct 22, 2019
8288831
Finalized Cifar10 example (#649)
vfdev-5 Oct 24, 2019
25db95b
Merge branch 'master' into distrib
vfdev-5 Oct 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge remote-tracking branch 'origin' into distrib
  • Loading branch information
vfdev-5 committed Aug 7, 2019
commit 1bda698b17667bfaaf94249978d075c0495357c9
25 changes: 11 additions & 14 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ignite.metrics import Metric, MetricsLambda
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import sync_all_reduce, reinit_is_reduced
from ignite.utils import to_onehot


class ConfusionMatrix(Metric):
Expand Down Expand Up @@ -77,26 +76,24 @@ def _check_shape(self, output):
if y_shape != y_pred_shape:
raise ValueError("y and y_pred must have compatible shapes.")

return y_pred, y

@reinit_is_reduced
def update(self, output):
self._check_shape(output)
y_pred, y = output

if y_pred.shape != y.shape:
y_ohe = to_onehot(y.reshape(-1), self.num_classes)
y_ohe_t = y_ohe.transpose(0, 1)
else:
y_ohe_t = y.transpose(0, 1).reshape(y.shape[1], -1)
y_ohe_t = y_ohe_t.to(self.confusion_matrix)
self._num_examples += y_pred.shape[0]

# target is (batch_size, ...)
y_pred = torch.argmax(y_pred, dim=1).flatten()
y = y.flatten()

indices = torch.argmax(y_pred, dim=1)
y_pred_ohe = to_onehot(indices.reshape(-1), self.num_classes)
y_pred_ohe = y_pred_ohe.to(self.confusion_matrix)
target_mask = (y >= 0) & (y < self.num_classes)
y = y[target_mask]
y_pred = y_pred[target_mask]

self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe)
self._num_examples += y_pred.shape[0]
indices = self.num_classes * y + y_pred
m = torch.bincount(indices, minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
self.confusion_matrix += m.to(self.confusion_matrix)

@sync_all_reduce('confusion_matrix', '_num_examples')
def compute(self):
Expand Down
136 changes: 2 additions & 134 deletions tests/ignite/metrics/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,57 +93,6 @@ def _test_N():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NC():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(10, num_classes)
y_labels = torch.randint(0, num_classes, size=(10,)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 10
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes)
y_labels = torch.randint(0, num_classes, size=(4, )).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# 2-classes
num_classes = 2
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes)
y_labels = torch.randint(0, num_classes, size=(4,)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(100, num_classes)
y_labels = torch.randint(0, num_classes, size=(100,)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurrences are rare
for _ in range(10):
_test_N()
Expand Down Expand Up @@ -189,47 +138,6 @@ def _test_NL():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NCL():
num_classes = 4
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(10, num_classes, 5)
y_labels = torch.randint(0, num_classes, size=(10, 5)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 10
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes, 5)
y_labels = torch.randint(0, num_classes, size=(4, 5)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 9
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(100, num_classes, 7)
y_labels = torch.randint(0, num_classes, size=(100, 7)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurrences are rare
for _ in range(10):
_test_NL()
Expand Down Expand Up @@ -274,46 +182,6 @@ def _test_NHW():
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

def _test_NCHW():
num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)

y_pred = torch.rand(4, num_classes, 12, 10)
y_labels = torch.randint(0, num_classes, size=(4, 12, 10)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

num_classes = 5
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(4, num_classes, 10, 12, 8)
y_labels = torch.randint(0, num_classes, size=(4, 10, 12, 8)).long()
y = to_onehot(y_labels, num_classes=num_classes)
cm.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y_labels.numpy().ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# Batched Updates
num_classes = 3
cm = ConfusionMatrix(num_classes=num_classes)
y_pred = torch.rand(100, num_classes, 8, 8)
y_labels = torch.randint(0, num_classes, size=(100, 8, 8)).long()
y = to_onehot(y_labels, num_classes=num_classes)

batch_size = 16
n_iters = y.shape[0] // batch_size + 1

for i in range(n_iters):
idx = i * batch_size
cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

np_y = y_labels.numpy().ravel()
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
assert np.all(confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes))) == cm.compute().numpy())

# check multiple random inputs as random exact occurrences are rare
for _ in range(10):
_test_NHW()
Expand Down Expand Up @@ -640,7 +508,7 @@ def _gather(y):
output = (th_y_logits, th_y_true)
cm.update(output)

res = cm.compute().numpy() / dist.get_world_size()
res = cm.compute().cpu().numpy() / dist.get_world_size()

assert np.all(true_res == res)

Expand Down Expand Up @@ -674,7 +542,7 @@ def _gather(y):
# Update metric & compute
output = (th_y_logits, th_y_true)
cm.update(output)
res = cm.compute().numpy()
res = cm.compute().cpu().numpy()

# Compute confusion matrix with sklearn
th_y_true = _gather(th_y_true)
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.