
Distrib #635

Merged
merged 29 commits on Oct 24, 2019
Changes from 1 commit
Commits
29 commits
2afc205
[WIP] Added cifar10 distributed example
vfdev-5 Aug 1, 2019
7b8eac9
[WIP] Metric with all reduce decorator and tests
vfdev-5 Aug 1, 2019
c7d2337
[WIP] Added tests for accumulation metric
vfdev-5 Aug 1, 2019
69ced1e
[WIP] Updated with reinit_is_reduced
vfdev-5 Aug 1, 2019
f2f923b
[WIP] Distrib adaptation for other metrics
vfdev-5 Aug 2, 2019
d13b985
[WIP] Warnings for EpochMetric and Precision/Recall when distrib
vfdev-5 Aug 2, 2019
e7d12d0
Updated metrics and tests to run on distributed configuration
vfdev-5 Aug 3, 2019
0a5f582
Minor fixes and cosmetics
vfdev-5 Aug 3, 2019
954269c
Merge branch 'master' into distrib
vfdev-5 Aug 3, 2019
206f2e1
Fixed bugs and improved contrib/cifar10 example
vfdev-5 Aug 3, 2019
99a6b4a
Updated docs
vfdev-5 Aug 3, 2019
3eff370
Update metrics.rst
vfdev-5 Aug 6, 2019
ad8375c
Updated docs and set device as "cuda" in distributed instead of raisi…
vfdev-5 Aug 6, 2019
0bcc287
[WIP] Fix missing _is_reduced in precision/recall with tests
vfdev-5 Aug 7, 2019
1bda698
Merge remote-tracking branch 'origin' into distrib
vfdev-5 Aug 7, 2019
7dd6937
Updated other tests
vfdev-5 Aug 7, 2019
27324dc
Merge branch 'master' into distrib
vfdev-5 Aug 29, 2019
f4a3d4b
Updated travis and renamed tbptt test gpu -> cuda
vfdev-5 Aug 29, 2019
2036075
Distrib (#573)
vfdev-5 Aug 30, 2019
69502fc
Merge branch 'distrib' of https://github.com/pytorch/ignite into distrib
vfdev-5 Sep 9, 2019
d52c36d
Merge branch 'master' into distrib
vfdev-5 Sep 9, 2019
ecb00a5
Merge branch 'master' into distrib
vfdev-5 Sep 13, 2019
71836aa
Merge branch 'master' into distrib
vfdev-5 Sep 25, 2019
46cdd86
Compute IoU, Precision, Recall based on CM on CPU
vfdev-5 Sep 26, 2019
fd14d4d
Fixes incomplete merge with 1856c8e0f1be102d4530592bcb7caac690f198c4
vfdev-5 Sep 26, 2019
59b894c
Merge branch 'master' into distrib
vfdev-5 Oct 17, 2019
80ad40a
Update distrib branch and CIFAR10 example (#647)
vfdev-5 Oct 22, 2019
8288831
Finalized Cifar10 example (#649)
vfdev-5 Oct 24, 2019
25db95b
Merge branch 'master' into distrib
vfdev-5 Oct 24, 2019
Update distrib branch and CIFAR10 example (#647)
* Added tests with gloo, minor updates and fixes

* Added single/multi node tests with gloo and [WIP] with nccl

* Added tests for multi-node nccl, improved examples/contrib/cifar10 example

* Experiments: 1n1gpu, 1n2gpus, 2n2gpus

* Fix flake8

* Fixes #645 (#646)

- fix CI and improve create_lr_scheduler_with_warmup

* Fix tests for python 2.7
vfdev-5 authored Oct 22, 2019
commit 80ad40a932827e1bf93524886f5bfa76dd0ed5db
8 changes: 4 additions & 4 deletions docs/source/metrics.rst
@@ -106,7 +106,7 @@ specific condition (e.g. ignore user-defined classes):
from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit_is_reduced
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class CustomAccuracy(Metric):
@@ -117,13 +117,13 @@ specific condition (e.g. ignore user-defined classes):
self._num_examples = None
super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

@reinit_is_reduced
@reinit__is_reduced
def reset(self):
self._num_correct = 0
self._num_examples = 0
super(CustomAccuracy, self).reset()

@reinit_is_reduced
@reinit__is_reduced
def update(self, output):
y_pred, y = output

@@ -178,7 +178,7 @@ We can check this implementation in a simple case:
Metrics and distributed computations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the above example, the `CustomAccuracy` constructor has a `device` argument, and the `reset`, `update`, `compute` methods are decorated with `reinit_is_reduced` and `sync_all_reduce`. The purpose of these features is to adapt metrics to distributed computations on CUDA devices, assuming the backend supports the `"all_reduce" operation <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`_. The user can specify the device (by default, `cuda`) at the metric's initialization. This device _can_ be used to store internal variables and to collect results from all participating devices. More precisely, in the above example we added `@sync_all_reduce("_num_examples", "_num_correct")` over the `compute` method. This means that when `compute` is called, the metric's internal variables `self._num_examples` and `self._num_correct` are summed up over all participating devices. Once collected, these internal variables can be used to compute the final metric value.
In the above example, the `CustomAccuracy` constructor has a `device` argument, and the `reset`, `update`, `compute` methods are decorated with `reinit__is_reduced` and `sync_all_reduce`. The purpose of these features is to adapt metrics to distributed computations on CUDA devices, assuming the backend supports the `"all_reduce" operation <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`_. The user can specify the device (by default, `cuda`) at the metric's initialization. This device _can_ be used to store internal variables and to collect results from all participating devices. More precisely, in the above example we added `@sync_all_reduce("_num_examples", "_num_correct")` over the `compute` method. This means that when `compute` is called, the metric's internal variables `self._num_examples` and `self._num_correct` are summed up over all participating devices. Once collected, these internal variables can be used to compute the final metric value.
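
For instance, here is a minimal sketch of how these decorators fit together (only the ignite imports below are real; the metric itself and its accumulator are illustrative):

.. code-block:: python

    from ignite.metrics import Metric
    # These two decorators come from ignite; everything else in this sketch is illustrative
    from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

    class SumOfSquares(Metric):

        @reinit__is_reduced
        def reset(self):
            # per-process state, reset before each run
            self._sum = 0.0

        @reinit__is_reduced
        def update(self, output):
            # accumulate locally on each process; `output` is assumed to be a tensor
            self._sum += float((output ** 2).sum().item())

        @sync_all_reduce("_sum")
        def compute(self):
            # at this point `self._sum` has been summed over all participating processes
            return self._sum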


Complete list of metrics
6 changes: 5 additions & 1 deletion examples/contrib/cifar10/.gitignore
@@ -1,2 +1,6 @@
output
cifar10
cifar10
.polyaxonignore
.polyaxon
plx_configs
gcp_configs
Binary file modified examples/contrib/cifar10/assets/tb_logger.png
97 changes: 60 additions & 37 deletions examples/contrib/cifar10/main.py
@@ -28,17 +28,18 @@
def run(output_path, config):
device = "cuda"

# Rescale batch_size and
ngpus_per_node = torch.cuda.device_count()
batch_size = config['batch_size'] // ngpus_per_node
num_workers = int((config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

local_rank = config['local_rank']

distributed = backend is not None
if distributed:
torch.cuda.set_device(config['local_rank'])
torch.cuda.set_device(local_rank)
device = "cuda"
rank = dist.get_rank() if distributed else 0

# Rescale batch_size and num_workers
ngpus_per_node = torch.cuda.device_count()
ngpus = dist.get_world_size() if distributed else 1
batch_size = config['batch_size'] // ngpus
num_workers = int((config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

train_labelled_loader, test_loader = \
get_train_test_loaders(path=config['data_path'],
@@ -99,29 +100,31 @@ def process_function(engine, labelled_batch):
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lambda engine: lr_scheduler.step())

if local_rank == 0:
metric_names = [
'batch loss',
]

def output_transform(x, name):
return x[name]

for n in metric_names:
# We compute running average values on the output (batch loss) across all devices
RunningAverage(output_transform=partial(output_transform, name=n),
epoch_bound=False, device=device).attach(trainer, n)

if rank == 0:
checkpoint_handler = ModelCheckpoint(dirname=output_path,
filename_prefix="checkpoint",
save_interval=1000)
trainer.add_event_handler(Events.ITERATION_COMPLETED,
checkpoint_handler,
{'model': model, 'optimizer': optimizer})
metric_names = [
'batch loss',
]

def output_transform(x, name):
return x[name]

for n in metric_names:
RunningAverage(output_transform=partial(output_transform, name=n), epoch_bound=False).attach(trainer, n)

ProgressBar(persist=True, bar_format="").attach(trainer,
event_name=Events.EPOCH_STARTED,
closing_event_name=Events.COMPLETED)

ProgressBar(persist=False, bar_format="").attach(trainer, metric_names=metric_names)
if config['display_iters']:
ProgressBar(persist=False, bar_format="").attach(trainer, metric_names=metric_names)

tb_logger = TensorboardLogger(log_dir=output_path)
tb_logger.attach(trainer,
@@ -142,15 +145,17 @@ def output_transform(x, name):

def run_validation(engine, val_interval):
if engine.state.epoch % val_interval == 0:
torch.cuda.synchronize()
train_evaluator.run(train_labelled_loader)
evaluator.run(test_loader)

trainer.add_event_handler(Events.EPOCH_STARTED, run_validation, val_interval=3)
trainer.add_event_handler(Events.COMPLETED, run_validation, val_interval=1)

if local_rank == 0:
ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)
if rank == 0:
if config['display_iters']:
ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

tb_logger.attach(train_evaluator,
log_handler=tbOutputHandler(tag="train",
@@ -180,6 +185,9 @@ def default_score_fn(engine):

trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

if rank == 0:
tb_logger.close()


if __name__ == "__main__":

@@ -220,6 +228,9 @@ def default_score_fn(engine):
# distributed settings
"dist_url": "env://",
"dist_backend": None, # if None distributed option is disabled, set to "nccl" to enable

# Logging:
"display_iters": True
}

if args.local_rank is not None:
@@ -238,7 +249,18 @@ def default_score_fn(engine):
backend = config['dist_backend']
distributed = backend is not None

if distributed:
dist.init_process_group(backend, init_method=config['dist_url'])
# let each node print the info
if config['local_rank'] == 0:
print("\nDistributed setting:")
print("\tbackend: {}".format(dist.get_backend()))
print("\tworld size: {}".format(dist.get_world_size()))
print("\trank: {}".format(dist.get_rank()))
print("\n")

output_path = None
# let each node print the info
if config['local_rank'] == 0:
print("Train {} on CIFAR10".format(network_name))
print("- PyTorch version: {}".format(torch.__version__))
@@ -251,21 +273,22 @@ def default_score_fn(engine):
print("\t{}: {}".format(key, value))
print("\n")

from datetime import datetime

now = datetime.now().strftime("%Y%m%d-%H%M%S")
gpu_conf = "-single-gpu"
if distributed:
gpu_conf = "-distributed-{}-gpus".format(torch.cuda.device_count())

output_path = Path(config['output_path']) / "{}{}".format(now, gpu_conf)
if not output_path.exists():
output_path.mkdir(parents=True)
output_path = output_path.as_posix()
print("Output path: {}".format(output_path))

if distributed:
dist.init_process_group(backend, init_method=config['dist_url'])
# create log directory only by 1 node
if (not distributed) or (dist.get_rank() == 0):
from datetime import datetime

now = datetime.now().strftime("%Y%m%d-%H%M%S")
gpu_conf = "-single-gpu"
if distributed:
ngpus_per_node = torch.cuda.device_count()
nnodes = dist.get_world_size() // ngpus_per_node
gpu_conf = "-distributed-{}nodes-{}gpus".format(nnodes, ngpus_per_node)

output_path = Path(config['output_path']) / "{}{}".format(now, gpu_conf)
if not output_path.exists():
output_path.mkdir(parents=True)
output_path = output_path.as_posix()
print("Output path: {}".format(output_path))

try:
run(output_path, config)
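
Below is a condensed sketch of the per-process setup pattern used in this example (a sketch only, assuming one process per GPU launched via `torch.distributed.launch`; the `config` keys mirror the ones above and the helper name is illustrative):

    # Sketch only: mirrors examples/contrib/cifar10/main.py above, not a drop-in replacement.
    import torch
    import torch.distributed as dist

    def setup_distributed(config):
        backend = config['dist_backend']              # e.g. "nccl"; None disables distributed mode
        distributed = backend is not None
        if distributed:
            torch.cuda.set_device(config['local_rank'])   # bind this process to its local GPU
            dist.init_process_group(backend, init_method=config['dist_url'])

        rank = dist.get_rank() if distributed else 0
        world_size = dist.get_world_size() if distributed else 1

        # The global batch size is split across all processes,
        # data loading workers are split across the GPUs of a node.
        ngpus_per_node = torch.cuda.device_count()
        batch_size = config['batch_size'] // world_size
        num_workers = (config['num_workers'] + ngpus_per_node - 1) // ngpus_per_node
        return distributed, rank, batch_size, num_workers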
11 changes: 9 additions & 2 deletions examples/contrib/cifar10/utils.py
@@ -1,3 +1,4 @@
import os
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

@@ -33,14 +34,20 @@ def get_train_test_loaders(path, batch_size, num_workers, distributed=False, pin
Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_ds = datasets.CIFAR10(root=path, train=True, download=True, transform=train_transform)
if not os.path.exists(path):
os.makedirs(path)
download = True
else:
download = True if len(os.listdir(path)) < 1 else False

train_ds = datasets.CIFAR10(root=path, train=True, download=download, transform=train_transform)
test_ds = datasets.CIFAR10(root=path, train=False, download=False, transform=test_transform)

train_sampler = None
test_sampler = None
if distributed:
train_sampler = DistributedSampler(train_ds)
test_sampler = DistributedSampler(test_ds)
test_sampler = DistributedSampler(test_ds, shuffle=False)

train_labelled_loader = DataLoader(train_ds, batch_size=batch_size, sampler=train_sampler,
num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
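
A short sketch of the sampler logic above, assuming already-built `train_ds`/`test_ds` datasets (the helper name and keyword defaults are illustrative):

    # Sketch of the loader pattern from utils.py above; arguments are illustrative.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def make_loaders(train_ds, test_ds, batch_size, num_workers, distributed=False):
        train_sampler = DistributedSampler(train_ds) if distributed else None
        # keep evaluation order fixed so each process sees a deterministic shard
        test_sampler = DistributedSampler(test_ds, shuffle=False) if distributed else None

        train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=train_sampler,
                                  shuffle=(train_sampler is None), num_workers=num_workers,
                                  pin_memory=True, drop_last=True)
        test_loader = DataLoader(test_ds, batch_size=batch_size, sampler=test_sampler,
                                 shuffle=False, num_workers=num_workers, pin_memory=True)
        return train_loader, test_loader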
91 changes: 62 additions & 29 deletions ignite/contrib/handlers/param_scheduler.py
@@ -423,7 +423,13 @@ class LRScheduler(ParamScheduler):

step_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
scheduler = LRScheduler(step_scheduler)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# In this example, we assume PyTorch>=1.1.0 is installed
# (with the new `torch.optim.lr_scheduler` behaviour) and
# we attach the scheduler to Events.ITERATION_COMPLETED
# instead of Events.ITERATION_STARTED to make sure the
# first lr value from the optimizer is used, otherwise it will be skipped:
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
"""

def __init__(self, lr_scheduler, save_history=False, **kwds):
@@ -450,7 +456,10 @@ def __call__(self, engine, name=None):
def get_param(self):
"""Method to get current optimizer's parameter value
"""
# Emulate context manager for pytorch>=1.4
self.lr_scheduler._get_lr_called_within_step = True
lr_list = self.lr_scheduler.get_lr()
self.lr_scheduler._get_lr_called_within_step = False
if len(lr_list) > 1:
raise ValueError("Optimizer passed to lr_scheduler should have a single param group, "
"but currently there are {} param groups".format(len(lr_list)))
@@ -475,8 +484,8 @@ def simulate_values(cls, num_events, lr_scheduler, **kwargs):
values = []
scheduler = cls(save_history=False, lr_scheduler=copy_lr_scheduler)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
scheduler(engine=None)

return values

@@ -493,7 +502,7 @@ def _replicate_lr_scheduler(lr_scheduler, new_optimizer_param_groups=None):
for group in dummy_optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
kwargs = lr_scheduler.state_dict()
for k in ['base_lrs', '_step_count']:
for k in [_k for _k in kwargs.keys() if "_" == _k[0]] + ['base_lrs', 'last_epoch']:
del kwargs[k]
copy_lr_scheduler = lr_scheduler_cls(optimizer=dummy_optimizer, **kwargs)
copy_lr_scheduler.load_state_dict(lr_scheduler.state_dict())
@@ -504,58 +513,82 @@ def create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end
save_history=False,
output_simulated_values=None):
"""
Helper method to create a LR scheduler with a linear warm-up.
Helper method to create a learning rate scheduler with a linear warm-up.

Args:
lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): LR scheduler after
the warm-up.
warmup_start_value (float): LR start value of the warm-up phase.
warmup_end_value (float): LR end value of the warm-up phase.
lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): learning rate scheduler
after the warm-up.
warmup_start_value (float): learning rate start value of the warm-up phase.
warmup_end_value (float): learning rate end value of the warm-up phase.
warmup_duration (int): warm-up phase duration, number of events.
save_history (bool, optional): whether to log the parameter values to
`engine.state.param_history`, (default=False).
output_simulated_values (list, optional): optional output of simulated LR values.
output_simulated_values (list, optional): optional output of simulated learning rate values.
If output_simulated_values is a list of None, e.g. `[None] * 100`, after the execution it will be filled
by 100 simulated LR values.
by 100 simulated learning rate values.

Returns:
ConcatScheduler: LR scheduler with linear warm-up.
ConcatScheduler: learning rate scheduler with linear warm-up.

Note:
If the first learning rate value provided by `lr_scheduler` is different from `warmup_end_value`, an additional
event is added after the warm-up phase so that the warm-up ends with `warmup_end_value`, after which
`lr_scheduler` provides its learning rate values as usual.

.. code-block:: python
Examples:

torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)
lr_values = [None] * 100
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=10,
output_simulated_values=lr_values)
lr_values = np.array(lr_values)
# Plot simulated values
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
.. code-block:: python

# Attach to the trainer
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)
lr_values = [None] * 100
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
warmup_start_value=0.0,
warmup_end_value=0.1,
warmup_duration=10,
output_simulated_values=lr_values)
lr_values = np.array(lr_values)
# Plot simulated values
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")

# Attach to the trainer
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

"""
if not isinstance(lr_scheduler, (ParamScheduler, _LRScheduler)):
raise TypeError("Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler or "
"ParamScheduler, but given {}".format(type(lr_scheduler)))

if not (isinstance(warmup_duration, numbers.Integral) and warmup_duration > 1):
raise ValueError("Argument warmup_duration should be at least 2 events, but given {}"
.format(warmup_duration))

milestones_values = [(0, warmup_start_value), (warmup_duration - 1, warmup_end_value)]

duration_extension = 0
if isinstance(lr_scheduler, _LRScheduler):
init_lrs = [g['lr'] for g in lr_scheduler.optimizer.param_groups]
if len(init_lrs) < 1:
raise RuntimeError("Number of parameter groups of input `lr_scheduler.optimizer` is less than one.")

if init_lrs[0] != warmup_end_value:
milestones_values.append((warmup_duration, init_lrs[0]))

lr_scheduler = LRScheduler(lr_scheduler)
duration_extension = 1
else:
init_lr = lr_scheduler.get_param()
if init_lr == warmup_end_value:
if warmup_duration > 2:
d = (warmup_end_value - warmup_start_value) / (warmup_duration - 1)
milestones_values[-1] = (warmup_duration - 2, warmup_end_value - d)
else:
milestones_values.pop(-1)

dummy_optimizer = {}
warmup_scheduler = LinearCyclicalScheduler(dummy_optimizer, param_name="lr",
start_value=warmup_start_value,
end_value=warmup_end_value,
cycle_size=warmup_duration * 2)

warmup_scheduler = PiecewiseLinear(dummy_optimizer, param_name="lr", milestones_values=milestones_values)
warmup_scheduler.optimizer_param_groups = lr_scheduler.optimizer_param_groups

schedulers = [warmup_scheduler, lr_scheduler]
durations = [warmup_duration + duration_extension, ]
durations = [milestones_values[-1][0] + 1, ]
combined_scheduler = ConcatScheduler(schedulers, durations=durations,
save_history=save_history)
if output_simulated_values is not None:
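
To illustrate the warm-up helper modified above, here is a brief usage sketch (the model and optimizer are illustrative; assumes PyTorch>=1.1.0):

    # Sketch only: illustrative model/optimizer; the helper is the one defined in param_scheduler.py above.
    import torch
    import torch.nn as nn
    from torch.optim.lr_scheduler import ExponentialLR
    from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    torch_lr_scheduler = ExponentialLR(optimizer, gamma=0.98)

    # Ramp the lr linearly from 0.0 to 0.1 over the first 10 events, then hand over to ExponentialLR.
    # `lr_values` is filled with simulated (event, lr) pairs that can be plotted as a sanity check.
    lr_values = [None] * 100
    scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                                warmup_start_value=0.0,
                                                warmup_end_value=0.1,
                                                warmup_duration=10,
                                                output_simulated_values=lr_values)
    # then, as in the docstring example: trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)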