Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distrib #635

Merged
merged 29 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2afc205
[WIP] Added cifar10 distributed example
vfdev-5 Aug 1, 2019
7b8eac9
[WIP] Metric with all reduce decorator and tests
vfdev-5 Aug 1, 2019
c7d2337
[WIP] Added tests for accumulation metric
vfdev-5 Aug 1, 2019
69ced1e
[WIP] Updated with reinit_is_reduced
vfdev-5 Aug 1, 2019
f2f923b
[WIP] Distrib adaptation for other metrics
vfdev-5 Aug 2, 2019
d13b985
[WIP] Warnings for EpochMetric and Precision/Recall when distrib
vfdev-5 Aug 2, 2019
e7d12d0
Updated metrics and tests to run on distributed configuration
vfdev-5 Aug 3, 2019
0a5f582
Minor fixes and cosmetics
vfdev-5 Aug 3, 2019
954269c
Merge branch 'master' into distrib
vfdev-5 Aug 3, 2019
206f2e1
Fixed bugs and improved contrib/cifar10 example
vfdev-5 Aug 3, 2019
99a6b4a
Updated docs
vfdev-5 Aug 3, 2019
3eff370
Update metrics.rst
vfdev-5 Aug 6, 2019
ad8375c
Updated docs and set device as "cuda" in distributed instead of raisi…
vfdev-5 Aug 6, 2019
0bcc287
[WIP] Fix missing _is_reduced in precision/recall with tests
vfdev-5 Aug 7, 2019
1bda698
Merge remote-tracking branch 'origin' into distrib
vfdev-5 Aug 7, 2019
7dd6937
Updated other tests
vfdev-5 Aug 7, 2019
27324dc
Merge branch 'master' into distrib
vfdev-5 Aug 29, 2019
f4a3d4b
Updated travis and renamed tbptt test gpu -> cuda
vfdev-5 Aug 29, 2019
2036075
Distrib (#573)
vfdev-5 Aug 30, 2019
69502fc
Merge branch 'distrib' of https://github.com/pytorch/ignite into distrib
vfdev-5 Sep 9, 2019
d52c36d
Merge branch 'master' into distrib
vfdev-5 Sep 9, 2019
ecb00a5
Merge branch 'master' into distrib
vfdev-5 Sep 13, 2019
71836aa
Merge branch 'master' into distrib
vfdev-5 Sep 25, 2019
46cdd86
Compute IoU, Precision, Recall based on CM on CPU
vfdev-5 Sep 26, 2019
fd14d4d
Fixes incomplete merge with 1856c8e0f1be102d4530592bcb7caac690f198c4
vfdev-5 Sep 26, 2019
59b894c
Merge branch 'master' into distrib
vfdev-5 Oct 17, 2019
80ad40a
Update distrib branch and CIFAR10 example (#647)
vfdev-5 Oct 22, 2019
8288831
Finalized Cifar10 example (#649)
vfdev-5 Oct 24, 2019
25db95b
Merge branch 'master' into distrib
vfdev-5 Oct 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[WIP] Metric with all reduce decorator and tests
  • Loading branch information
vfdev-5 committed Aug 1, 2019
commit 7b8eac95612c86652e6b55c2eaf8da6436225571
28 changes: 19 additions & 9 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from ignite.metrics import Metric
from ignite.metrics.metric import sync_all_reduce
from ignite.exceptions import NotComputableError


Expand All @@ -27,20 +28,22 @@ class VariableAccumulation(Metric):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device): device specification in case of distributed computation usage.
In most of the cases, it should defined as "cuda:local_rank".

"""

def __init__(self, op, output_transform=lambda x: x):
def __init__(self, op, output_transform=lambda x: x, device=None):
if not callable(op):
raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
self.accumulator = None
self.num_examples = None
self._op = op
super(VariableAccumulation, self).__init__(output_transform=output_transform)
super(VariableAccumulation, self).__init__(output_transform=output_transform, device=device)

def reset(self):
self.accumulator = torch.tensor(0.0, dtype=torch.float64)
self.num_examples = torch.tensor(0.0, dtype=torch.float64)
self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self.num_examples = torch.tensor(0.0, dtype=torch.long, device=self._device)
super(VariableAccumulation, self).reset()

def _check_output_type(self, output):
Expand All @@ -55,7 +58,8 @@ def update(self, output):
self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
else:
self.num_examples += 1


@sync_all_reduce('accumulator', 'num_examples')
def compute(self):
return [self.accumulator, self.num_examples]

Expand Down Expand Up @@ -91,15 +95,18 @@ class Average(VariableAccumulation):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device): device specification in case of distributed computation usage.
In most of the cases, it should defined as "cuda:local_rank".

"""
def __init__(self, output_transform=lambda x: x):
def __init__(self, output_transform=lambda x: x, device=None):

def _mean_op(a, x):
return a + x

super(Average, self).__init__(op=_mean_op, output_transform=output_transform)
super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device)

@sync_all_reduce('accumulator', 'num_examples')
def compute(self):
if self.num_examples < 1:
raise NotComputableError("{} must have at least one example before"
Expand Down Expand Up @@ -127,17 +134,20 @@ class GeometricAverage(VariableAccumulation):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device): device specification in case of distributed computation usage.
In most of the cases, it should defined as "cuda:local_rank".

"""
def __init__(self, output_transform=lambda x: x):
def __init__(self, output_transform=lambda x: x, device=None):

def _geom_op(a, x):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
return a + torch.log(x)

super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform)
super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device)

@sync_all_reduce('accumulator', 'num_examples')
def compute(self):
if self.num_examples < 1:
raise NotComputableError("{} must have at least one example before"
Expand Down
4 changes: 2 additions & 2 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def update(self, output):
y_pred_ohe = y_pred_ohe.float()

if self.confusion_matrix.type() != y_ohe_t.type():
self.confusion_matrix = self.confusion_matrix.type_as(y_ohe_t)
self.confusion_matrix = self.confusion_matrix.to(y_ohe_t)

self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe).float()
self.confusion_matrix += torch.matmul(y_ohe_t, y_pred_ohe)
self._num_examples += y_pred.shape[0]

def compute(self):
Expand Down
71 changes: 69 additions & 2 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import numbers
from abc import ABCMeta, abstractmethod
from functools import wraps

try:
from collections.abc import Sequence
except ImportError: # Python 2.7 compatibility
from collections import Sequence

import torch

from ignite._six import with_metaclass
from ignite.engine import Events
import torch


class Metric(with_metaclass(ABCMeta, object)):
Expand All @@ -13,11 +22,21 @@ class Metric(with_metaclass(ABCMeta, object)):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device): device specification in case of distributed computation usage.
In most of the cases, it should defined as "cuda:local_rank".

"""

def __init__(self, output_transform=lambda x: x):
def __init__(self, output_transform=lambda x: x, device=None):
self._output_transform = output_transform

# Check device if distributed is initialized:
if torch.distributed.is_available() and torch.distributed.is_initialized():
if device is None:
raise ValueError("Please provide the device for distributed computation. "
"In most of the cases, it should defined as 'cuda:local_rank'.")
device = torch.device(device)
self._device = device
self.reset()

@abstractmethod
Expand Down Expand Up @@ -55,6 +74,32 @@ def compute(self):
NotComputableError: raised when the metric cannot be computed.
"""
pass

def _sync_all_reduce(self, tensor):
if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
# Nothing to reduce
return tensor

tensor_to_number = False
if isinstance(tensor, numbers.Number):
tensor = torch.tensor(tensor, device=self._device)
tensor_to_number = True

# synchronize and all reduce

if isinstance(tensor, torch.Tensor):
# check if the tensor is at specified device
if tensor.device != self._device:
tensor = tensor.to(self._device)
else:
raise TypeError("Unhandled input type {}".format(type(tensor)))

torch.distributed.barrier()
torch.distributed.all_reduce(tensor)

if tensor_to_number:
return tensor.item()
return tensor

def started(self, engine):
self.reset()
Expand Down Expand Up @@ -146,3 +191,25 @@ def wrapper(*args, **kwargs):
def __getitem__(self, index):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x: x[index], self)


def sync_all_reduce(*attrs):

def wraper(func):

@wraps(func)
def another_wrapper(self, *args, **kwargs):
if not isinstance(self, Metric):
raise RuntimeError("Decorator sync_all_reduce should be used on "
"ignite.metric.Metric class methods only")

if len(attrs) > 0:
for attr in attrs:
t = getattr(self, attr, None)
if t is not None:
t = self._sync_all_reduce(t)
setattr(self, attr, t)

return func(self, *args, **kwargs)
return another_wrapper
return wraper
30 changes: 30 additions & 0 deletions tests/ignite/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

import torch
import torch.distributed as dist


@pytest.fixture()
def local_rank(worker_id):
""" use a different account in each xdist worker """
return int(worker_id.replace("gw", ""))


@pytest.fixture()
def distributed_context_single_node(local_rank):
# import os
# os.environ["WORLD_SIZE"] = "{}".format(torch.cuda.device_count())
# os.environ["RANK"] = "{}".format(local_rank)

dist_info = {
"backend": "nccl",
"world_size": torch.cuda.device_count(),
"rank": local_rank,
"init_method": "tcp://localhost:2222"
}

g = dist.init_process_group(**dist_info)

yield g

dist.destroy_process_group()
102 changes: 99 additions & 3 deletions tests/ignite/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from mock import MagicMock

import pytest
from pytest import approx, raises
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
Expand All @@ -22,7 +23,7 @@ def compute(self):

def update(self, output):
assert output == (y_pred, y)

metric = DummyMetric()
state = State(output=(y_pred, y))
engine = MagicMock(state=state)
Expand All @@ -41,8 +42,8 @@ def compute(self):
pass

def update(self, output):
assert output == (y_pred, y)

assert output == (y_pred, y)
def transform(output):
pred_dict, target_dict = output
return pred_dict['y'], target_dict['y']
Expand Down Expand Up @@ -445,3 +446,98 @@ def data(y_pred, y):
_test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
labels = [1]
_test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))


class DummyMetric(Metric):
def reset(self):
pass

def compute(self):
pass

def update(self, output):
pass


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if no GPU")
def test_distrib_no_device_metric(distributed_context_single_node):

import torch.distributed as dist
assert dist.is_available() and dist.is_initialized()

with pytest.raises(ValueError, match=r"Please provide the device for distributed computation."):
DummyMetric()


def test__sync_all_reduce():
m = DummyMetric()
res = m._sync_all_reduce(10)
assert res == 10


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if no GPU")
def test_distrib__sync_all_reduce(local_rank, distributed_context_single_node):

import torch.distributed as dist
assert dist.is_available() and dist.is_initialized()

# This test should be the first in the list, otherwise stucked
m = DummyMetric(device="cuda:{}".format(local_rank))
t = torch.tensor(10, device="cuda:1")
res = m._sync_all_reduce(t)
assert res.item() == 10 * dist.get_world_size()

m = DummyMetric(device="cuda:{}".format(local_rank))
res = m._sync_all_reduce(10)
assert res == 10 * dist.get_world_size()

m = DummyMetric(device="cuda:{}".format(local_rank))
t = torch.tensor(10, device="cuda:{}".format(local_rank))
res = m._sync_all_reduce(t)
assert res.item() == 10 * dist.get_world_size()

m = DummyMetric(device="cuda:{}".format(local_rank))
with pytest.raises(TypeError, match=r"Unhandled input type"):
m._sync_all_reduce("abc")


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if no GPU")
def test_distrib_sync_all_reduce_decorator(local_rank, distributed_context_single_node):

import torch.distributed as dist
assert dist.is_available() and dist.is_initialized()

from ignite.metrics.metric import sync_all_reduce

class DummyMetric(Metric):

def reset(self):
self.a = torch.tensor([0, 1, 2, 3], device=self._device, requires_grad=False)
self.a_nocomp = self.a.clone().to('cpu')
self.b = torch.tensor(1.0, dtype=torch.float64, device=self._device, requires_grad=False)
self.b_nocomp = self.b.clone().to("cpu")
self.c = 0.0
self.c_nocomp = self.c
self.n = 0
self.n_nocomp = self.n

@sync_all_reduce("a", "b", "c", "n")
def compute(self):
assert (self.a.cpu() == (self.a_nocomp + 10) * dist.get_world_size()).all()
assert (self.b.cpu() == (self.b_nocomp - 5) * dist.get_world_size()).all()
assert pytest.approx(self.c == (self.c_nocomp + 1.23456) * dist.get_world_size())
assert pytest.approx(self.n == (self.n_nocomp + 1) * dist.get_world_size())

def update(self, output):
self.n += 1
self.c += 1.23456
self.a += 10.0
self.b -= 5.0

m = DummyMetric(device="cuda:{}".format(local_rank))
m.update(None)
m.compute()

4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E722,E741,F401,F403,F405,F821,F841,F999

[pytest]
markers =
distributed: mark a test with distributed option