Skip to content

Commit

Permalink
[WIP] Fix missing _is_reduced in precision/recall with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Aug 7, 2019
1 parent ad8375c commit 0bcc287
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def compute(self):
if not (self._type == "multilabel" and not self._average):
self._true_positives = self._sync_all_reduce(self._true_positives)
self._positives = self._sync_all_reduce(self._positives)
self._is_reduced = True

result = self._true_positives / (self._positives + self.eps)

Expand Down
11 changes: 9 additions & 2 deletions tests/ignite/metrics/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,13 @@ def update(engine, i):

assert "pr" in engine.state.metrics
res = engine.state.metrics['pr']
res2 = pr.compute()
if isinstance(res, torch.Tensor):
res = res.cpu().numpy()
res2 = res2.cpu().numpy()
assert (res == res2).all()
else:
assert res == res2

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UndefinedMetricWarning)
Expand All @@ -828,7 +833,9 @@ def update(engine, i):
y_pred = torch.randint(0, 2, size=(10, 5, 18, 16))
y = torch.randint(0, 2, size=(10, 5, 18, 16)).long()
pr.update((y_pred, y))
pr_compute = pr.compute()
assert len(pr_compute) == 10 * 18 * 16
pr_compute1 = pr.compute()
pr_compute2 = pr.compute()
assert len(pr_compute1) == 10 * 18 * 16
assert (pr_compute1 == pr_compute2).all()

test_distrib_itegration_multilabel()
13 changes: 10 additions & 3 deletions tests/ignite/metrics/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,16 +806,21 @@ def update(engine, i):

assert "re" in engine.state.metrics
res = engine.state.metrics['re']
res2 = re.compute()
if isinstance(res, torch.Tensor):
res = res.cpu().numpy()
res2 = res2.cpu().numpy()
assert (res == res2).all()
else:
assert res == res2

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UndefinedMetricWarning)
true_res = recall_score(to_numpy_multilabel(y_true),
to_numpy_multilabel(y_preds),
average='samples' if average else None)

assert pytest.approx(res) == true_res
assert pytest.approx(res) == true_res

for _ in range(5):
_test(average=True, n_epochs=1)
Expand All @@ -828,7 +833,9 @@ def update(engine, i):
y_pred = torch.randint(0, 2, size=(10, 5, 18, 16))
y = torch.randint(0, 2, size=(10, 5, 18, 16)).long()
re.update((y_pred, y))
re_compute = re.compute()
assert len(re_compute) == 10 * 18 * 16
re_compute1 = re.compute()
re_compute2 = re.compute()
assert len(re_compute1) == 10 * 18 * 16
assert (re_compute1 == re_compute2).all()

test_distrib_itegration_multilabel()

0 comments on commit 0bcc287

Please sign in to comment.